Commit 32e4ca51 authored by qianyj's avatar qianyj
Browse files

Update code to v2.11.0

parents 9485aa1d 71060f67
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Detection input and model functions for serving/inference."""
from typing import Dict, Mapping, Text
import tensorflow as tf
from official.projects.deepmac_maskrcnn.configs import deep_mask_head_rcnn as cfg
from official.projects.deepmac_maskrcnn.modeling import maskrcnn_model
from official.projects.deepmac_maskrcnn.tasks import deep_mask_head_rcnn
from official.vision.ops import box_ops
from official.vision.serving import detection
def reverse_input_box_transformation(boxes, image_info):
  """Undoes the box transformation applied to the Mask R-CNN model's inputs.

  Args:
    boxes: A [batch_size, num_boxes, 4] float tensor of boxes in normalized
      coordinates.
    image_info: A `Tensor` describing each image and the preprocessing applied
      to it, laid out as
      [[original_height, original_width], [desired_height, desired_width],
       [y_scale, x_scale], [y_offset, x_offset]], where
      [desired_height, desired_width] is the actual scaled image size and
      [y_scale, x_scale] is scaled dimension / original dimension. The slicing
      below indexes three axes, i.e. it expects a leading batch axis in front
      of that layout.

  Returns:
    A tensor with the same shape as `boxes`, expressed in the absolute
    coordinate space of the preprocessed image.
  """
  # This mirrors, in reverse, what Detection_module.serve does when
  # output_normalized_coordinates=true.
  original_size = image_info[:, 0:1, :]
  yx_scale = image_info[:, 2:3, :]
  # Repeat the two scale values along the last axis so that they line up with
  # the four box coordinates.
  per_coordinate_scale = tf.tile(yx_scale, [1, 1, 2])
  scaled_boxes = boxes * per_coordinate_scale
  return box_ops.denormalize_boxes(scaled_boxes, original_size)
class DetectionModule(detection.DetectionModule):
  """Detection export module for DeepMAC Mask R-CNN models.

  On top of the signatures provided by the parent module, this class can export
  an `image_and_boxes_tensor` signature that takes images together with
  user-supplied boxes and predicts masks at those box locations.
  """

  def _build_model(self):
    """Builds the model described by `self.params.task.model`.

    Returns:
      The model returned by `deep_mask_head_rcnn.build_maskrcnn`.

    Raises:
      ValueError: If the model config is not a `cfg.DeepMaskHeadRCNN`.
    """
    # These two checks used to construct a ValueError without raising it, so
    # they never had any effect. They are surfaced as warnings instead of hard
    # failures so that configurations that exported successfully before keep
    # working.
    # NOTE(review): raising may have been the original intent -- confirm
    # against the configs in use before turning these into errors.
    if self._batch_size is None:
      tf.get_logger().warning(
          "batch_size can't be None for detection models")
    if self.params.task.model.detection_generator.nms_version != 'batched':
      tf.get_logger().warning('Only batched_nms is supported.')

    input_specs = tf.keras.layers.InputSpec(shape=[self._batch_size] +
                                            self._input_image_size + [3])

    if isinstance(self.params.task.model, cfg.DeepMaskHeadRCNN):
      model = deep_mask_head_rcnn.build_maskrcnn(
          input_specs=input_specs, model_config=self.params.task.model)
    else:
      raise ValueError('Detection module not implemented for {} model.'.format(
          type(self.params.task.model)))

    return model

  @tf.function
  def inference_for_tflite_image_and_boxes(
      self, images: tf.Tensor, boxes: tf.Tensor) -> Mapping[str, tf.Tensor]:
    """A tf-function for serve_image_and_boxes.

    Args:
      images: A [batch_size, height, width, channels] float tensor.
      boxes: A [batch_size, num_boxes, 4] float tensor containing boxes
        normalized to the input image.

    Returns:
      result: A dict containing:
        'detection_masks': A [batch_size, num_boxes, mask_height, mask_width]
          float tensor containing per-pixel mask probabilities.

    Raises:
      ValueError: If `self.model` is not a `DeepMaskRCNNModel`.
    """
    if not isinstance(self.model, maskrcnn_model.DeepMaskRCNNModel):
      raise ValueError(
          ('Can only use image and boxes input for DeepMaskRCNNModel, '
           'Found {}'.format(type(self.model))))

    return self.serve_image_and_boxes(images, boxes)

  def serve_image_and_boxes(self, images: tf.Tensor, boxes: tf.Tensor):
    """Function used to export a model that consumes an image and boxes.

    The model predicts the class-agnostic masks at the given box locations.

    Args:
      images: A [batch_size, height, width, channels] float tensor.
      boxes: A [batch_size, num_boxes, 4] float tensor containing boxes
        normalized to the input image.

    Returns:
      result: A dict containing:
        'detection_masks': A [batch_size, num_boxes, mask_height, mask_width]
          float tensor containing per-pixel mask probabilities.
    """
    images, _, image_info = self.preprocess(images)
    # The model consumes boxes in the coordinate space of the preprocessed
    # image, so the normalized input boxes are converted first.
    boxes = reverse_input_box_transformation(boxes, image_info)
    result = self.model.call_images_and_boxes(images, boxes)
    return result

  def get_inference_signatures(self, function_keys: Dict[Text, Text]):
    """Builds the serving signatures requested in `function_keys`.

    Args:
      function_keys: Maps an input type to the name of the signature to export
        for it. The `image_and_boxes_tensor` input type is handled here; every
        other entry is delegated to the parent class. The dict is not
        modified.

    Returns:
      A dict mapping signature names to concrete functions.
    """
    signatures = {}
    # Work on a copy so that the caller's dict is left untouched when the
    # locally handled key is removed before delegating to the parent.
    remaining_keys = dict(function_keys)

    if 'image_and_boxes_tensor' in remaining_keys:
      def_name = remaining_keys.pop('image_and_boxes_tensor')
      image_signature = tf.TensorSpec(
          shape=[self._batch_size] + [None] * len(self._input_image_size) +
          [self._num_channels],
          dtype=tf.uint8)
      boxes_signature = tf.TensorSpec(shape=[self._batch_size, None, 4],
                                      dtype=tf.float32)
      tf_function = self.inference_for_tflite_image_and_boxes
      signatures[def_name] = tf_function.get_concrete_function(
          image_signature, boxes_signature)

    parent_signatures = super().get_inference_signatures(remaining_keys)
    signatures.update(parent_signatures)

    return signatures
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test for image detection export lib."""
import io
import os
from absl.testing import parameterized
import numpy as np
from PIL import Image
import tensorflow as tf
from official.core import exp_factory
from official.projects.deepmac_maskrcnn.serving import detection
class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
  """Exports the DeepMAC serving module and checks the reloaded signatures."""

  def _get_detection_module(self, experiment_name, image_size=(640, 640)):
    """Creates a batch-size-1 `DetectionModule` for the given experiment."""
    exp_config = exp_factory.get_exp_config(experiment_name)
    model_config = exp_config.task.model
    model_config.backbone.resnet.model_id = 18
    model_config.detection_generator.use_batched_nms = True
    return detection.DetectionModule(
        exp_config, batch_size=1, input_image_size=list(image_size))

  def _export_from_module(self, module, input_type, save_directory):
    """Saves `module` with a single `serving_default` signature."""
    serving_signatures = module.get_inference_signatures(
        {input_type: 'serving_default'})
    tf.saved_model.save(
        module, save_directory, signatures=serving_signatures)

  def _get_dummy_input(self, input_type, batch_size, image_size):
    """Get dummy input for the given input type."""
    height, width = image_size

    if input_type == 'image_tensor':
      return tf.zeros((batch_size, height, width, 3), dtype=np.uint8)

    if input_type == 'image_bytes':
      blank_image = Image.fromarray(
          np.zeros((height, width, 3), dtype=np.uint8))
      png_buffer = io.BytesIO()
      blank_image.save(png_buffer, 'PNG')
      return [png_buffer.getvalue() for _ in range(batch_size)]

    if input_type == 'tf_example':
      blank_tensor = tf.zeros((height, width, 3), dtype=tf.uint8)
      jpeg_bytes = tf.image.encode_jpeg(tf.constant(blank_tensor)).numpy()
      image_feature = tf.train.Feature(
          bytes_list=tf.train.BytesList(value=[jpeg_bytes]))
      serialized_example = tf.train.Example(
          features=tf.train.Features(
              feature={'image/encoded': image_feature})).SerializeToString()
      return [serialized_example for _ in range(batch_size)]

    # Unknown input types produce no dummy input.
    return None

  def _assert_saved_model_files_exist(self, export_dir):
    """Asserts that the SavedModel proto and variable files were written."""
    self.assertTrue(os.path.exists(os.path.join(export_dir, 'saved_model.pb')))
    self.assertTrue(
        os.path.exists(
            os.path.join(export_dir, 'variables', 'variables.index')))
    self.assertTrue(
        os.path.exists(
            os.path.join(export_dir, 'variables',
                         'variables.data-00000-of-00001')))

  @parameterized.parameters(
      ('image_tensor', 'deep_mask_head_rcnn_resnetfpn_coco', [640, 640]),
      ('image_bytes', 'deep_mask_head_rcnn_resnetfpn_coco', [640, 384]),
      ('tf_example', 'deep_mask_head_rcnn_resnetfpn_coco', [640, 640]),
  )
  def test_export(self, input_type, experiment_name, image_size):
    """Compares detection counts of the exported and in-memory models."""
    self.skipTest('a')

    export_dir = self.get_temp_dir()
    module = self._get_detection_module(experiment_name, image_size)
    self._export_from_module(module, input_type, export_dir)
    self._assert_saved_model_files_exist(export_dir)

    reloaded = tf.saved_model.load(export_dir)
    detection_fn = reloaded.signatures['serving_default']

    images = self._get_dummy_input(
        input_type, batch_size=1, image_size=image_size)

    processed_images, anchor_boxes, image_info = module._build_inputs(
        tf.zeros((224, 224, 3), dtype=tf.uint8))
    # Add the batch axis that the model expects to every input.
    image_shape = tf.expand_dims(image_info[1, :], 0)
    processed_images = tf.expand_dims(processed_images, 0)
    for level, level_boxes in anchor_boxes.items():
      anchor_boxes[level] = tf.expand_dims(level_boxes, 0)

    expected_outputs = module.model(
        images=processed_images,
        image_shape=image_shape,
        anchor_boxes=anchor_boxes,
        training=False)
    outputs = detection_fn(tf.constant(images))

    self.assertAllClose(outputs['num_detections'].numpy(),
                        expected_outputs['num_detections'].numpy())

  @parameterized.parameters(
      ('deep_mask_head_rcnn_resnetfpn_coco', [640, 640], 1),
      ('deep_mask_head_rcnn_resnetfpn_coco', [640, 640], 5),
      ('deep_mask_head_rcnn_spinenet_coco', [640, 384], 3),
      ('deep_mask_head_rcnn_spinenet_coco', [640, 384], 9),
  )
  def test_export_image_and_boxes(self, experiment_name, image_size, num_boxes):
    """Compares masks of the exported image-and-boxes signature and model."""
    export_dir = self.get_temp_dir()
    module = self._get_detection_module(experiment_name)
    self._export_from_module(module, 'image_and_boxes_tensor', export_dir)
    self._assert_saved_model_files_exist(export_dir)

    reloaded = tf.saved_model.load(export_dir)
    detection_fn = reloaded.signatures['serving_default']

    images = self._get_dummy_input(
        'image_tensor', batch_size=1, image_size=image_size)

    processed_images, anchor_boxes, image_info = module._build_inputs(
        tf.zeros(image_size + [3], dtype=tf.uint8))
    # Add the batch axis to every input.
    image_shape = image_info[1, :]
    image_shape = image_shape[tf.newaxis]
    processed_images = processed_images[tf.newaxis]
    image_info = image_info[tf.newaxis]
    for level, level_boxes in anchor_boxes.items():
      anchor_boxes[level] = tf.expand_dims(level_boxes, 0)

    # Every box has zeros in its first two coordinates and ones in the last
    # two.
    box_values = np.zeros((1, num_boxes, 4), dtype=np.float32)
    box_values[:, :, 2:] = 1.0
    boxes = tf.constant(box_values)

    denormalized_boxes = detection.reverse_input_box_transformation(
        boxes, image_info)
    expected_outputs = module.model.call_images_and_boxes(
        images=processed_images, boxes=denormalized_boxes)
    outputs = detection_fn(images=tf.constant(images), boxes=boxes)

    self.assertAllClose(outputs['detection_masks'].numpy(),
                        expected_outputs['detection_masks'].numpy(),
                        rtol=1e-3, atol=1e-3)
# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
  tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Deepmac model export binary for serving/inference.
To export a trained checkpoint in saved_model format (shell script):
CHECKPOINT_PATH = XX
EXPORT_DIR_PATH = XX
CONFIG_FILE_PATH = XX
export_saved_model --export_dir=${EXPORT_DIR_PATH}/ \
--checkpoint_path=${CHECKPOINT_PATH} \
--config_file=${CONFIG_FILE_PATH} \
--batch_size=2 \
--input_image_size=224,224
To serve (python):
export_dir_path = XX
input_type = XX
input_images = XX
imported = tf.saved_model.load(export_dir_path)
model_fn = imported.signatures['serving_default']
output = model_fn(input_images)
"""
from absl import app
from absl import flags
from official.core import exp_factory
from official.modeling import hyperparams
from official.projects.deepmac_maskrcnn.serving import detection
from official.projects.deepmac_maskrcnn.tasks import deep_mask_head_rcnn # pylint: disable=unused-import
from official.vision.serving import export_saved_model_lib
FLAGS = flags.FLAGS

# Experiment selection and output/input locations.
flags.DEFINE_string(
    name='experiment',
    default='deep_mask_head_rcnn_resnetfpn_coco',
    help='experiment type, e.g. retinanet_resnetfpn_coco')
flags.DEFINE_string(
    name='export_dir', default=None, help='The export directory.')
flags.DEFINE_string(
    name='checkpoint_path', default=None, help='Checkpoint path.')

# Config overrides; `main` applies the files first and `params_override` last.
flags.DEFINE_multi_string(
    name='config_file',
    default=None,
    help='YAML/JSON files which specifies overrides. The override order '
    'follows the order of args. Note that each file '
    'can be used as an override template to override the default parameters '
    'specified in Python. If the same parameter is specified in both '
    '`--config_file` and `--params_override`, `config_file` will be used '
    'first, followed by params_override.')
flags.DEFINE_string(
    name='params_override',
    default='',
    help='The JSON/YAML file or string which specifies the parameter to be overriden'
    ' on top of `config_file` template.')

# Shape and type of the exported serving input.
flags.DEFINE_integer(name='batch_size', default=None, help='The batch size.')
flags.DEFINE_string(
    name='input_type',
    default='image_tensor',
    help=('One of `image_tensor`, `image_bytes`, `tf_example` '
          'or `image_and_boxes_tensor`.'))
flags.DEFINE_string(
    name='input_image_size',
    default='224,224',
    help='The comma-separated string of two integers representing the height,width '
    'of the input to the model.')
def main(_):
  """Builds the experiment config from flags and exports the inference graph.

  Config files are applied in the order given, followed by
  `--params_override`; the resulting config is validated and locked before
  the export module is created.
  """
  params = exp_factory.get_exp_config(FLAGS.experiment)

  # Collect every override source in application order, then apply them with
  # a single loop.
  override_sources = list(FLAGS.config_file or [])
  if FLAGS.params_override:
    override_sources.append(FLAGS.params_override)
  for override in override_sources:
    params = hyperparams.override_params_dict(
        params, override, is_strict=True)

  params.validate()
  params.lock()

  # Parse "height,width" once; each consumer below receives its own list.
  image_size = [int(dim) for dim in FLAGS.input_image_size.split(',')]

  export_module = detection.DetectionModule(
      params=params,
      batch_size=FLAGS.batch_size,
      input_image_size=list(image_size),
      num_channels=3)

  export_saved_model_lib.export_inference_graph(
      input_type=FLAGS.input_type,
      batch_size=FLAGS.batch_size,
      input_image_size=list(image_size),
      params=params,
      checkpoint_path=FLAGS.checkpoint_path,
      export_dir=FLAGS.export_dir,
      export_module=export_module,
      export_checkpoint_subdir='checkpoint',
      export_saved_model_subdir='saved_model')
# Script entry point: absl parses the command-line flags and then calls main.
if __name__ == '__main__':
  app.run(main)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mask R-CNN variant with support for deep mask heads."""
import tensorflow as tf
from official.core import task_factory
from official.projects.deepmac_maskrcnn.configs import deep_mask_head_rcnn as deep_mask_head_rcnn_config
from official.projects.deepmac_maskrcnn.modeling import maskrcnn_model as deep_maskrcnn_model
from official.projects.deepmac_maskrcnn.modeling.heads import instance_heads as deep_instance_heads
from official.vision.modeling import backbones
from official.vision.modeling.decoders import factory as decoder_factory
from official.vision.modeling.heads import dense_prediction_heads
from official.vision.modeling.heads import instance_heads
from official.vision.modeling.layers import detection_generator
from official.vision.modeling.layers import mask_sampler
from official.vision.modeling.layers import roi_aligner
from official.vision.modeling.layers import roi_generator
from official.vision.modeling.layers import roi_sampler
from official.vision.tasks import maskrcnn
# Taken from modeling/factory.py
def build_maskrcnn(input_specs: tf.keras.layers.InputSpec,
                   model_config: deep_mask_head_rcnn_config.DeepMaskHeadRCNN,
                   l2_regularizer: tf.keras.regularizers.Regularizer = None):  # pytype: disable=annotation-type-mismatch  # typed-keras
  """Assembles a `DeepMaskRCNNModel` from a DeepMaskHeadRCNN model config.

  Args:
    input_specs: Input spec handed to the backbone factory.
    model_config: Model config from which every sub-component is configured.
    l2_regularizer: Optional regularizer forwarded to the backbone, the
      decoder and the heads.

  Returns:
    The constructed `DeepMaskRCNNModel`. Its mask head, mask sampler and mask
    ROI aligner are None when `model_config.include_mask` is false.
  """
  norm_cfg = model_config.norm_activation

  backbone = backbones.factory.build_backbone(
      input_specs=input_specs,
      backbone_config=model_config.backbone,
      norm_activation_config=norm_cfg,
      l2_regularizer=l2_regularizer)

  decoder = decoder_factory.build_decoder(
      input_specs=backbone.output_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)

  # Keyword arguments that the RPN head and the detection head have in common.
  shared_head_kwargs = dict(
      activation=norm_cfg.activation,
      use_sync_bn=norm_cfg.use_sync_bn,
      norm_momentum=norm_cfg.norm_momentum,
      norm_epsilon=norm_cfg.norm_epsilon,
      kernel_regularizer=l2_regularizer)

  rpn_cfg = model_config.rpn_head
  anchor_cfg = model_config.anchor
  rpn_head = dense_prediction_heads.RPNHead(
      min_level=model_config.min_level,
      max_level=model_config.max_level,
      num_anchors_per_location=(
          len(anchor_cfg.aspect_ratios) * anchor_cfg.num_scales),
      num_convs=rpn_cfg.num_convs,
      num_filters=rpn_cfg.num_filters,
      use_separable_conv=rpn_cfg.use_separable_conv,
      **shared_head_kwargs)

  det_head_cfg = model_config.detection_head
  detection_head = instance_heads.DetectionHead(
      num_classes=model_config.num_classes,
      num_convs=det_head_cfg.num_convs,
      num_filters=det_head_cfg.num_filters,
      use_separable_conv=det_head_cfg.use_separable_conv,
      num_fcs=det_head_cfg.num_fcs,
      fc_dims=det_head_cfg.fc_dims,
      **shared_head_kwargs)

  roi_gen_cfg = model_config.roi_generator
  roi_generator_obj = roi_generator.MultilevelROIGenerator(
      pre_nms_top_k=roi_gen_cfg.pre_nms_top_k,
      pre_nms_score_threshold=roi_gen_cfg.pre_nms_score_threshold,
      pre_nms_min_size_threshold=roi_gen_cfg.pre_nms_min_size_threshold,
      nms_iou_threshold=roi_gen_cfg.nms_iou_threshold,
      num_proposals=roi_gen_cfg.num_proposals,
      test_pre_nms_top_k=roi_gen_cfg.test_pre_nms_top_k,
      test_pre_nms_score_threshold=roi_gen_cfg.test_pre_nms_score_threshold,
      test_pre_nms_min_size_threshold=(
          roi_gen_cfg.test_pre_nms_min_size_threshold),
      test_nms_iou_threshold=roi_gen_cfg.test_nms_iou_threshold,
      test_num_proposals=roi_gen_cfg.test_num_proposals,
      use_batched_nms=roi_gen_cfg.use_batched_nms)

  sampler_cfg = model_config.roi_sampler
  roi_sampler_obj = roi_sampler.ROISampler(
      mix_gt_boxes=sampler_cfg.mix_gt_boxes,
      num_sampled_rois=sampler_cfg.num_sampled_rois,
      foreground_fraction=sampler_cfg.foreground_fraction,
      foreground_iou_threshold=sampler_cfg.foreground_iou_threshold,
      background_iou_high_threshold=sampler_cfg.background_iou_high_threshold,
      background_iou_low_threshold=sampler_cfg.background_iou_low_threshold)

  aligner_cfg = model_config.roi_aligner
  roi_aligner_obj = roi_aligner.MultilevelROIAligner(
      crop_size=aligner_cfg.crop_size,
      sample_offset=aligner_cfg.sample_offset)

  det_gen_cfg = model_config.detection_generator
  detection_generator_obj = detection_generator.DetectionGenerator(
      apply_nms=True,
      pre_nms_top_k=det_gen_cfg.pre_nms_top_k,
      pre_nms_score_threshold=det_gen_cfg.pre_nms_score_threshold,
      nms_iou_threshold=det_gen_cfg.nms_iou_threshold,
      max_num_detections=det_gen_cfg.max_num_detections,
      nms_version=det_gen_cfg.nms_version)

  # The mask components stay None unless the config asks for masks.
  mask_head = None
  mask_sampler_obj = None
  mask_roi_aligner_obj = None
  if model_config.include_mask:
    mask_head_cfg = model_config.mask_head
    mask_aligner_cfg = model_config.mask_roi_aligner

    mask_head = deep_instance_heads.DeepMaskHead(
        num_classes=model_config.num_classes,
        upsample_factor=mask_head_cfg.upsample_factor,
        num_convs=mask_head_cfg.num_convs,
        num_filters=mask_head_cfg.num_filters,
        use_separable_conv=mask_head_cfg.use_separable_conv,
        activation=norm_cfg.activation,
        norm_momentum=norm_cfg.norm_momentum,
        norm_epsilon=norm_cfg.norm_epsilon,
        kernel_regularizer=l2_regularizer,
        class_agnostic=mask_head_cfg.class_agnostic,
        convnet_variant=mask_head_cfg.convnet_variant)

    mask_sampler_obj = mask_sampler.MaskSampler(
        mask_target_size=(
            mask_aligner_cfg.crop_size * mask_head_cfg.upsample_factor),
        num_sampled_masks=model_config.mask_sampler.num_sampled_masks)

    mask_roi_aligner_obj = roi_aligner.MultilevelROIAligner(
        crop_size=mask_aligner_cfg.crop_size,
        sample_offset=mask_aligner_cfg.sample_offset)

  return deep_maskrcnn_model.DeepMaskRCNNModel(
      backbone=backbone,
      decoder=decoder,
      rpn_head=rpn_head,
      detection_head=detection_head,
      roi_generator=roi_generator_obj,
      roi_sampler=roi_sampler_obj,
      roi_aligner=roi_aligner_obj,
      detection_generator=detection_generator_obj,
      mask_head=mask_head,
      mask_sampler=mask_sampler_obj,
      mask_roi_aligner=mask_roi_aligner_obj,
      use_gt_boxes_for_masks=model_config.use_gt_boxes_for_masks)
@task_factory.register_task_cls(deep_mask_head_rcnn_config.DeepMaskHeadRCNNTask)
class DeepMaskHeadRCNNTask(maskrcnn.MaskRCNNTask):
  """Mask R-CNN task variant that builds a model with a deep mask head."""

  def build_model(self):
    """Builds the Mask R-CNN model from the task config.

    Returns:
      The model produced by `build_maskrcnn`; its backbone is set to
      non-trainable when `task_config.freeze_backbone` is set.
    """
    task_cfg = self.task_config

    weight_decay = task_cfg.losses.l2_weight_decay
    if weight_decay:
      # Divide weight decay by 2.0 to match the implementation of
      # tf.nn.l2_loss.
      # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
      # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
      regularizer = tf.keras.regularizers.l2(weight_decay / 2.0)
    else:
      regularizer = None

    model = build_maskrcnn(
        input_specs=tf.keras.layers.InputSpec(
            shape=[None] + task_cfg.model.input_size),
        model_config=task_cfg.model,
        l2_regularizer=regularizer)

    if task_cfg.freeze_backbone:
      model.backbone.trainable = False

    return model
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorFlow Model Garden Vision training driver."""
from absl import app
from absl import flags
from absl import logging
import gin
from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
# pylint: disable=unused-import
from official.projects.deepmac_maskrcnn.common import registry_imports
# pylint: enable=unused-import
FLAGS = flags.FLAGS
def main(_):
  """Parses the configuration from flags and runs the experiment."""
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  params = train_utils.parse_configuration(FLAGS)
  model_dir = FLAGS.model_dir
  mode = FLAGS.mode

  # Pure eval modes do not output yaml files. Otherwise continuous eval job
  # may race against the train job for writing the same file.
  if 'train' in mode:
    train_utils.serialize_config(params, model_dir)

  runtime = params.runtime

  # Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can have significant impact on model speeds by utilizing float16 in case of
  # GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
  # dtype is float16
  if runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(runtime.mixed_precision_dtype)

  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=runtime.distribution_strategy,
      all_reduce_alg=runtime.all_reduce_alg,
      num_gpus=runtime.num_gpus,
      tpu_address=runtime.tpu)

  # The task is created under the strategy scope.
  with strategy.scope():
    task = task_factory.get_task(params.task, logging_dir=model_dir)
  logging.info('Training with task %s', task)

  train_lib.run_experiment(
      distribution_strategy=strategy,
      task=task,
      mode=mode,
      params=params,
      model_dir=model_dir)

  train_utils.save_gin_config(mode, model_dir)
# Script entry point: the flags read by main are defined before absl parses
# the command line and calls main.
if __name__ == '__main__':
  tfm_flags.define_flags()
  app.run(main)
# End-to-End Object Detection with Transformers (DETR)
[![DETR](https://img.shields.io/badge/DETR-arXiv.2005.12872-B3181B?)](https://arxiv.org/abs/2005.12872)
TensorFlow 2 implementation of End-to-End Object Detection with Transformers.
⚠️ Disclaimer: None of the datasets hyperlinked from this page are owned or
distributed by Google. The datasets are made available by third parties.
Please review the terms and conditions made available by those third parties
before using the data.
## Scripts:
You can find the scripts to reproduce the following experiments in
detr/experiments.
## DETR [COCO](https://cocodataset.org) ([ImageNet](https://www.image-net.org) pretrained)
| Model | Resolution | Batch size | Epochs | Decay@ | Params (M) | Box AP | Dashboard | Checkpoint | Experiment |
| --------- | :--------: | ----------:| ------:| -----: | ---------: | -----: | --------: | ---------: | ---------: |
| DETR-ResNet-50 | 1333x1333 |64|300| 200 |41 | 40.6 | [tensorboard](https://tensorboard.dev/experiment/o2IEZnniRYu6pqViBeopIg/#scalars) | [ckpt](https://storage.googleapis.com/tf_model_garden/vision/detr/detr_resnet_50_300.tar.gz) | detr_r50_300epochs.sh |
| DETR-ResNet-50 | 1333x1333 |64|500| 400 |41 | 42.0| [tensorboard](https://tensorboard.dev/experiment/YFMDKpESR4yjocPh5HgfRw/) | [ckpt](https://storage.googleapis.com/tf_model_garden/vision/detr/detr_resnet_50_500.tar.gz) | detr_r50_500epochs.sh |
| DETR-ResNet-50 | 1333x1333 |64|300| 200 |41 | 40.6 | paper | NA | NA |
| DETR-ResNet-50 | 1333x1333 |64|500| 400 |41 | 42.0 | paper | NA | NA |
| DETR-DC5-ResNet-50 | 1333x1333 |64|500| 400 |41 | 43.3 | paper | NA | NA |
## Need contribution:
* Add DC5 support and update experiment table.
## Citing TensorFlow Model Garden
If you find this codebase helpful in your research, please cite this repository.
```
@misc{tensorflowmodelgarden2020,
author = {Hongkun Yu and Chen Chen and Xianzhi Du and Yeqing Li and
Abdullah Rashwan and Le Hou and Pengchong Jin and Fan Yang and
Frederick Liu and Jaeyoun Kim and Jing Li},
title = {{TensorFlow Model Garden}},
howpublished = {\url{https://github.com/tensorflow/models}},
year = {2020}
}
```
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DETR configurations."""
import dataclasses
import os
from typing import List, Optional, Union
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.projects.detr import optimization
from official.projects.detr.dataloaders import coco
from official.vision.configs import backbones
from official.vision.configs import common
@dataclasses.dataclass
class DataConfig(cfg.DataConfig):
  """Input config for training."""
  input_path: str = ''
  tfds_name: str = ''
  tfds_split: str = 'train'
  global_batch_size: int = 0
  is_training: bool = False
  dtype: str = 'bfloat16'
  # A config instance must not be used directly as a dataclass default:
  # Python 3.11+ rejects unhashable defaults with a ValueError at class
  # creation time. A factory builds a fresh decoder config per instance.
  decoder: common.DataDecoder = dataclasses.field(
      default_factory=common.DataDecoder)
  shuffle_buffer_size: int = 10000
  file_type: str = 'tfrecord'
  drop_remainder: bool = True
@dataclasses.dataclass
class Losses(hyperparams.Config):
  """Loss-related hyperparameters of the DETR task config."""
  # NOTE(review): the lambda_* names suggest per-term weights for the
  # classification, box and GIoU losses -- confirm against the task's loss
  # code, which is not part of this file.
  class_offset: int = 0
  lambda_cls: float = 1.0
  lambda_box: float = 5.0
  lambda_giou: float = 2.0
  background_cls_weight: float = 0.1
  l2_weight_decay: float = 1e-4
@dataclasses.dataclass
class Detr(hyperparams.Config):
  """Detr model definitions."""
  num_queries: int = 100
  hidden_size: int = 256
  num_classes: int = 91  # 0: background
  num_encoder_layers: int = 6
  num_decoder_layers: int = 6
  input_size: List[int] = dataclasses.field(default_factory=list)
  # Sub-configs are built through factories, like `input_size` above: Python
  # 3.11+ rejects unhashable instances as dataclass defaults with a
  # ValueError at class creation time.
  backbone: backbones.Backbone = dataclasses.field(
      default_factory=lambda: backbones.Backbone(
          type='resnet',
          resnet=backbones.ResNet(model_id=50, bn_trainable=False)))
  norm_activation: common.NormActivation = dataclasses.field(
      default_factory=common.NormActivation)
  backbone_endpoint_name: str = '5'
@dataclasses.dataclass
class DetrTask(cfg.TaskConfig):
  """Task config for DETR: model, data, losses and checkpoint settings."""
  # Sub-configs are built through factories rather than given as instance
  # defaults: Python 3.11+ rejects unhashable dataclass defaults with a
  # ValueError at class creation time.
  model: Detr = dataclasses.field(default_factory=Detr)
  train_data: cfg.DataConfig = dataclasses.field(
      default_factory=cfg.DataConfig)
  validation_data: cfg.DataConfig = dataclasses.field(
      default_factory=cfg.DataConfig)
  losses: Losses = dataclasses.field(default_factory=Losses)
  init_checkpoint: Optional[str] = None
  init_checkpoint_modules: Union[str, List[str]] = 'all'  # all, backbone
  annotation_file: Optional[str] = None
  per_category_metrics: bool = False
# Base directory joined with the annotation file name and the train/val file
# patterns in the tfrecord experiment config.
COCO_INPUT_PATH_BASE = 'coco'
# Example counts used below to derive steps per epoch and validation steps.
COCO_TRAIN_EXAMPLES = 118287
COCO_VAL_EXAMPLES = 5000
@exp_factory.register_config_factory('detr_coco')
def detr_coco() -> cfg.ExperimentConfig:
  """Config to get results that matches the paper.

  Reads COCO 2017 through TFDS and trains for 500 epochs, with the learning
  rate stepped down for the final 100 epochs.

  Returns:
    The experiment config registered as `detr_coco`.
  """
  train_batch_size = 64
  eval_batch_size = 64
  steps_per_epoch = COCO_TRAIN_EXAMPLES // train_batch_size
  train_steps = 500 * steps_per_epoch  # 500 epochs
  decay_at = train_steps - 100 * steps_per_epoch  # 400 epochs

  train_data = coco.COCODataConfig(
      tfds_name='coco/2017',
      tfds_split='train',
      is_training=True,
      global_batch_size=train_batch_size,
      shuffle_buffer_size=1000,
  )
  validation_data = coco.COCODataConfig(
      tfds_name='coco/2017',
      tfds_split='validation',
      is_training=False,
      global_batch_size=eval_batch_size,
      drop_remainder=False)

  task = DetrTask(
      init_checkpoint='',
      init_checkpoint_modules='backbone',
      model=Detr(
          num_classes=81,
          input_size=[1333, 1333, 3],
          norm_activation=common.NormActivation()),
      losses=Losses(),
      train_data=train_data,
      validation_data=validation_data)

  optimizer_settings = {
      'type': 'detr_adamw',
      'detr_adamw': {
          'weight_decay_rate': 1e-4,
          'global_clipnorm': 0.1,
          # Avoid AdamW legacy behavior.
          'gradient_clip_norm': 0.0
      }
  }
  learning_rate_settings = {
      'type': 'stepwise',
      'stepwise': {
          'boundaries': [decay_at],
          'values': [0.0001, 1.0e-05]
      }
  }

  trainer = cfg.TrainerConfig(
      train_steps=train_steps,
      validation_steps=-1,
      steps_per_loop=10000,
      summary_interval=10000,
      checkpoint_interval=10000,
      validation_interval=10000,
      max_to_keep=1,
      best_checkpoint_export_subdir='best_ckpt',
      best_checkpoint_eval_metric='AP',
      optimizer_config=optimization.OptimizationConfig({
          'optimizer': optimizer_settings,
          'learning_rate': learning_rate_settings,
      }))

  return cfg.ExperimentConfig(
      task=task,
      trainer=trainer,
      restrictions=[
          'task.train_data.is_training != None',
      ])
@exp_factory.register_config_factory('detr_coco_tfrecord')
def detr_coco_tfrecord() -> cfg.ExperimentConfig:
  """Config to get results that matches the paper (TFRecord input).

  Reads `train*` / `val*` files and the `instances_val2017.json` annotation
  file from under `COCO_INPUT_PATH_BASE`. Trains for 300 epochs at a global
  batch size of 64 and lowers the learning rate from 1e-4 to 1e-5 once,
  100 epochs before the end of training.

  Returns:
    The `cfg.ExperimentConfig` registered under `detr_coco_tfrecord`.
  """
  batch_train = 64
  batch_eval = 64
  epoch_steps = COCO_TRAIN_EXAMPLES // batch_train
  total_steps = 300 * epoch_steps  # 300 epochs
  lr_drop_step = total_steps - 100 * epoch_steps  # 200 epochs

  model_config = Detr(
      input_size=[1333, 1333, 3],
      norm_activation=common.NormActivation())
  train_input = DataConfig(
      input_path=os.path.join(COCO_INPUT_PATH_BASE, 'train*'),
      is_training=True,
      global_batch_size=batch_train,
      shuffle_buffer_size=1000,
  )
  eval_input = DataConfig(
      input_path=os.path.join(COCO_INPUT_PATH_BASE, 'val*'),
      is_training=False,
      global_batch_size=batch_eval,
      drop_remainder=False,
  )
  task_config = DetrTask(
      init_checkpoint='',
      init_checkpoint_modules='backbone',
      annotation_file=os.path.join(COCO_INPUT_PATH_BASE,
                                   'instances_val2017.json'),
      model=model_config,
      losses=Losses(),
      train_data=train_input,
      validation_data=eval_input)

  optimizer_settings = {
      'type': 'detr_adamw',
      'detr_adamw': {
          'weight_decay_rate': 1e-4,
          'global_clipnorm': 0.1,
          # Avoid AdamW legacy behavior.
          'gradient_clip_norm': 0.0
      }
  }
  learning_rate_settings = {
      'type': 'stepwise',
      'stepwise': {
          'boundaries': [lr_drop_step],
          'values': [0.0001, 1.0e-05]
      }
  }
  trainer_config = cfg.TrainerConfig(
      train_steps=total_steps,
      validation_steps=COCO_VAL_EXAMPLES // batch_eval,
      steps_per_loop=epoch_steps,
      summary_interval=epoch_steps,
      checkpoint_interval=epoch_steps,
      validation_interval=5 * epoch_steps,
      max_to_keep=1,
      best_checkpoint_export_subdir='best_ckpt',
      best_checkpoint_eval_metric='AP',
      optimizer_config=optimization.OptimizationConfig({
          'optimizer': optimizer_settings,
          'learning_rate': learning_rate_settings,
      }))

  return cfg.ExperimentConfig(
      task=task_config,
      trainer=trainer_config,
      restrictions=[
          'task.train_data.is_training != None',
      ])
@exp_factory.register_config_factory('detr_coco_tfds')
def detr_coco_tfds() -> cfg.ExperimentConfig:
  """Config to get results that matches the paper (TFDS via `DataConfig`).

  Uses the TFDS `coco/2017` splits with 81 model classes and a loss class
  offset of 1. Trains for 300 epochs at a global batch size of 64 and lowers
  the learning rate from 1e-4 to 1e-5 once, 100 epochs before the end.

  Returns:
    The `cfg.ExperimentConfig` registered under `detr_coco_tfds`.
  """
  batch_train = 64
  batch_eval = 64
  epoch_steps = COCO_TRAIN_EXAMPLES // batch_train
  total_steps = 300 * epoch_steps  # 300 epochs
  lr_drop_step = total_steps - 100 * epoch_steps  # 200 epochs

  model_config = Detr(
      num_classes=81,
      input_size=[1333, 1333, 3],
      norm_activation=common.NormActivation())
  train_input = DataConfig(
      tfds_name='coco/2017',
      tfds_split='train',
      is_training=True,
      global_batch_size=batch_train,
      shuffle_buffer_size=1000,
  )
  eval_input = DataConfig(
      tfds_name='coco/2017',
      tfds_split='validation',
      is_training=False,
      global_batch_size=batch_eval,
      drop_remainder=False)
  task_config = DetrTask(
      init_checkpoint='',
      init_checkpoint_modules='backbone',
      model=model_config,
      losses=Losses(class_offset=1),
      train_data=train_input,
      validation_data=eval_input)

  optimizer_settings = {
      'type': 'detr_adamw',
      'detr_adamw': {
          'weight_decay_rate': 1e-4,
          'global_clipnorm': 0.1,
          # Avoid AdamW legacy behavior.
          'gradient_clip_norm': 0.0
      }
  }
  learning_rate_settings = {
      'type': 'stepwise',
      'stepwise': {
          'boundaries': [lr_drop_step],
          'values': [0.0001, 1.0e-05]
      }
  }
  trainer_config = cfg.TrainerConfig(
      train_steps=total_steps,
      validation_steps=COCO_VAL_EXAMPLES // batch_eval,
      steps_per_loop=epoch_steps,
      summary_interval=epoch_steps,
      checkpoint_interval=epoch_steps,
      validation_interval=5 * epoch_steps,
      max_to_keep=1,
      best_checkpoint_export_subdir='best_ckpt',
      best_checkpoint_eval_metric='AP',
      optimizer_config=optimization.OptimizationConfig({
          'optimizer': optimizer_settings,
          'learning_rate': learning_rate_settings,
      }))

  return cfg.ExperimentConfig(
      task=task_config,
      trainer=trainer_config,
      restrictions=[
          'task.train_data.is_training != None',
      ])
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for detr."""
# pylint: disable=unused-import
from absl.testing import parameterized
import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.projects.detr.configs import detr as exp_cfg
from official.projects.detr.dataloaders import coco
class DetrTest(tf.test.TestCase, parameterized.TestCase):
  """Checks that the registered DETR experiment configs build and validate."""

  @parameterized.parameters(('detr_coco',))
  def test_detr_configs_tfds(self, config_name):
    """The `detr_coco` config uses `COCODataConfig` and enforces restrictions."""
    config = exp_factory.get_exp_config(config_name)
    self.assertIsInstance(config, cfg.ExperimentConfig)
    self.assertIsInstance(config.task, exp_cfg.DetrTask)
    self.assertIsInstance(config.task.train_data, coco.COCODataConfig)
    # Violating the `is_training != None` restriction must fail validation.
    config.task.train_data.is_training = None
    with self.assertRaises(KeyError):
      config.validate()

  # Each parameter set is an explicit 1-tuple; `('name')` is just a string.
  @parameterized.parameters(('detr_coco_tfrecord',), ('detr_coco_tfds',))
  def test_detr_configs(self, config_name):
    """The `DataConfig`-based configs build and enforce their restrictions."""
    config = exp_factory.get_exp_config(config_name)
    self.assertIsInstance(config, cfg.ExperimentConfig)
    self.assertIsInstance(config.task, exp_cfg.DetrTask)
    self.assertIsInstance(config.task.train_data, cfg.DataConfig)
    # Violating the `is_training != None` restriction must fail validation.
    config.task.train_data.is_training = None
    with self.assertRaises(KeyError):
      config.validate()
# Run the test cases only when this file is executed as a script.
if __name__ == '__main__':
  tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""COCO data loader for DETR."""
import dataclasses
from typing import Optional, Tuple
import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import input_reader
from official.vision.ops import box_ops
from official.vision.ops import preprocess_ops
@dataclasses.dataclass
class COCODataConfig(cfg.DataConfig):
  """Data config for the DETR COCO input pipeline.

  Attributes:
    output_size: (height, width) that images are padded to; its maximum is
      also passed to the resize step as the size limit.
    max_num_boxes: fixed number of entries that per-example box labels are
      clipped or padded to.
    resize_scales: candidate resize targets; one is drawn at random for
      training and the last entry is used for evaluation.
  """
  output_size: Tuple[int, int] = (1333, 1333)
  max_num_boxes: int = 100
  resize_scales: Tuple[int, ...] = (
      480,
      512,
      544,
      576,
      608,
      640,
      672,
      704,
      736,
      768,
      800,
  )
class COCODataLoader():
  """A class to load dataset for COCO detection task.

  Examples are read through `input_reader.InputReader`, preprocessed one at a
  time by `preprocess`, and then batched by `_transform_and_batch_fn`.
  """

  def __init__(self, params: COCODataConfig):
    """Stores the data configuration.

    Args:
      params: the `COCODataConfig` that drives reading and preprocessing.
    """
    self._params = params

  def preprocess(self, inputs):
    """Preprocess COCO for DETR.

    Args:
      inputs: a dict with the keys `image` and `objects` (itself holding
        `bbox`, `label` and `is_crowd`). `image/id` is read as well when the
        config is not in training mode. Boxes are presumably normalized yxyx,
        since they are passed to `box_ops.denormalize_boxes` and
        `box_ops.yxyx_to_cycxhw` -- confirm against the dataset.

    Returns:
      A tuple `(image, labels)`. `image` is normalized, resized and padded to
      `output_size`. `labels` holds `classes` (dataset label + 1) and `boxes`
      (normalized, converted with `yxyx_to_cycxhw`), both clipped or padded to
      `max_num_boxes`. When not training it additionally holds `id`,
      `image_info`, `is_crowd` and `gt_boxes` (the boxes denormalized by the
      image shape before the final resize).
    """
    image = inputs['image']
    boxes = inputs['objects']['bbox']
    classes = inputs['objects']['label'] + 1
    is_crowd = inputs['objects']['is_crowd']

    image = preprocess_ops.normalize_image(image)
    if self._params.is_training:
      image, boxes, _ = preprocess_ops.random_horizontal_flip(image, boxes)

      # With probability 0.5, rescale and take a random crop before the final
      # multi-scale resize.
      do_crop = tf.greater(tf.random.uniform([]), 0.5)
      if do_crop:
        # Rescale to a randomly chosen size from {400, 500, 600}.
        boxes = box_ops.denormalize_boxes(boxes, tf.shape(image)[:2])
        index = tf.random.categorical(tf.zeros([1, 3]), 1)[0]
        scales = tf.gather([400.0, 500.0, 600.0], index, axis=0)
        short_side = scales[0]
        image, image_info = preprocess_ops.resize_image(image, short_side)
        boxes = preprocess_ops.resize_and_crop_boxes(boxes,
                                                     image_info[2, :],
                                                     image_info[1, :],
                                                     image_info[3, :])
        boxes = box_ops.normalize_boxes(boxes, image_info[1, :])

        # Random crop: each side is drawn from [384, min(side, 600)) and the
        # offset is drawn so that the window stays inside the image.
        shape = tf.cast(image_info[1], dtype=tf.int32)
        h = tf.random.uniform(
            [], 384, tf.math.minimum(shape[0], 600), dtype=tf.int32)
        w = tf.random.uniform(
            [], 384, tf.math.minimum(shape[1], 600), dtype=tf.int32)
        i = tf.random.uniform([], 0, shape[0] - h + 1, dtype=tf.int32)
        j = tf.random.uniform([], 0, shape[1] - w + 1, dtype=tf.int32)
        image = tf.image.crop_to_bounding_box(image, i, j, h, w)
        # Re-express the boxes relative to the crop window, clipped to [0, 1].
        boxes = tf.clip_by_value(
            (boxes[..., :] * tf.cast(
                tf.stack([shape[0], shape[1], shape[0], shape[1]]),
                dtype=tf.float32) -
             tf.cast(tf.stack([i, j, i, j]), dtype=tf.float32)) /
            tf.cast(tf.stack([h, w, h, w]), dtype=tf.float32), 0.0, 1.0)
      scales = tf.constant(
          self._params.resize_scales,
          dtype=tf.float32)
      # Draw uniformly over however many scales are configured instead of
      # assuming the default tuple of 11 entries.
      num_scales = len(self._params.resize_scales)
      index = tf.random.categorical(tf.zeros([1, num_scales]), 1)[0]
      scales = tf.gather(scales, index, axis=0)
    else:
      # Evaluation always uses the last configured scale.
      scales = tf.constant([self._params.resize_scales[-1]], tf.float32)

    image_shape = tf.shape(image)[:2]
    boxes = box_ops.denormalize_boxes(boxes, image_shape)
    gt_boxes = boxes
    short_side = scales[0]
    image, image_info = preprocess_ops.resize_image(
        image,
        short_side,
        max(self._params.output_size))
    boxes = preprocess_ops.resize_and_crop_boxes(boxes,
                                                 image_info[2, :],
                                                 image_info[1, :],
                                                 image_info[3, :])
    boxes = box_ops.normalize_boxes(boxes, image_info[1, :])

    # Filters out ground truth boxes that are all zeros.
    indices = box_ops.get_non_empty_box_indices(boxes)
    boxes = tf.gather(boxes, indices)
    classes = tf.gather(classes, indices)
    is_crowd = tf.gather(is_crowd, indices)
    boxes = box_ops.yxyx_to_cycxhw(boxes)

    image = tf.image.pad_to_bounding_box(
        image, 0, 0, self._params.output_size[0], self._params.output_size[1])
    labels = {
        'classes':
            preprocess_ops.clip_or_pad_to_fixed_size(
                classes, self._params.max_num_boxes),
        'boxes':
            preprocess_ops.clip_or_pad_to_fixed_size(
                boxes, self._params.max_num_boxes)
    }
    if not self._params.is_training:
      labels.update({
          'id':
              inputs['image/id'],
          'image_info':
              image_info,
          'is_crowd':
              preprocess_ops.clip_or_pad_to_fixed_size(
                  is_crowd, self._params.max_num_boxes),
          'gt_boxes':
              preprocess_ops.clip_or_pad_to_fixed_size(
                  gt_boxes, self._params.max_num_boxes),
      })

    return image, labels

  def _transform_and_batch_fn(
      self,
      dataset,
      input_context: Optional[tf.distribute.InputContext] = None):
    """Preprocess and batch.

    Args:
      dataset: the dataset of raw examples to map `preprocess` over.
      input_context: optional distribution context; when given, the global
        batch size is converted to a per-replica batch size.

    Returns:
      The preprocessed and batched dataset.
    """
    dataset = dataset.map(
        self.preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    per_replica_batch_size = input_context.get_per_replica_batch_size(
        self._params.global_batch_size
    ) if input_context else self._params.global_batch_size
    dataset = dataset.batch(
        per_replica_batch_size, drop_remainder=self._params.drop_remainder)
    return dataset

  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a tf.dataset.Dataset.

    Args:
      input_context: optional distribution context forwarded to the reader.
    """
    reader = input_reader.InputReader(
        params=self._params,
        decoder_fn=None,
        transform_and_batch_fn=self._transform_and_batch_fn)
    return reader.read(input_context)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tensorflow_models.official.projects.detr.dataloaders.coco."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from official.projects.detr.dataloaders import coco
def _gen_fn():
  """Generates one random, COCO-shaped example for the data loader tests.

  Returns:
    A dict with an all-ones `image` of random size, a random `image/id`, a
    fixed `image/filename`, and an `objects` dict whose arrays all share the
    same random leading dimension (possibly zero boxes).
  """
  # Start at 1 so that a degenerate zero-sized image is never produced.
  h = np.random.randint(1, 300)
  w = np.random.randint(1, 300)
  num_boxes = np.random.randint(0, 50)
  return {
      'image': np.ones(shape=(h, w, 3), dtype=np.uint8),
      'image/id': np.random.randint(0, 100),
      'image/filename': 'test',
      'objects': {
          # `np.bool` was removed from NumPy; `np.bool_` is the scalar type.
          'is_crowd': np.ones(shape=(num_boxes), dtype=np.bool_),
          'bbox': np.ones(shape=(num_boxes, 4), dtype=np.float32),
          'label': np.ones(shape=(num_boxes), dtype=np.int64),
          'id': np.ones(shape=(num_boxes), dtype=np.int64),
          'area': np.ones(shape=(num_boxes), dtype=np.int64),
      }
  }
class CocoDataloaderTest(tf.test.TestCase, parameterized.TestCase):
  """Shape and dtype checks for `coco.COCODataLoader`."""

  def test_load_dataset(self):
    """A batch loaded from mocked TFDS data has the configured shapes."""
    output_size = 1280
    max_num_boxes = 100
    batch_size = 2
    data_config = coco.COCODataConfig(
        tfds_name='coco/2017',
        tfds_split='validation',
        is_training=False,
        global_batch_size=batch_size,
        output_size=(output_size, output_size),
        max_num_boxes=max_num_boxes,
    )

    num_examples = 10

    def as_dataset(builder, *args, **kwargs):
      # Replacement for the builder's dataset creation: serve generated
      # examples that follow the builder's declared feature dtypes and shapes.
      del args
      del kwargs
      return tf.data.Dataset.from_generator(
          lambda: (_gen_fn() for _ in range(num_examples)),
          output_types=builder.info.features.dtype,
          output_shapes=builder.info.features.shape,
      )

    with tfds.testing.mock_data(num_examples=num_examples,
                                as_dataset_fn=as_dataset):
      dataset = coco.COCODataLoader(data_config).load()
      dataset_iter = iter(dataset)
      images, labels = next(dataset_iter)
      self.assertEqual(images.shape, (batch_size, output_size, output_size, 3))
      self.assertEqual(labels['classes'].shape, (batch_size, max_num_boxes))
      self.assertEqual(labels['boxes'].shape, (batch_size, max_num_boxes, 4))
      self.assertEqual(labels['id'].shape, (batch_size,))
      self.assertEqual(
          labels['image_info'].shape, (batch_size, 4, 2))
      self.assertEqual(labels['is_crowd'].shape, (batch_size, max_num_boxes))

  @parameterized.named_parameters(
      ('training', True),
      ('validation', False))
  def test_preprocess(self, is_training):
    """A single preprocessed example has the configured shapes."""
    output_size = 1280
    max_num_boxes = 100
    batch_size = 2
    data_config = coco.COCODataConfig(
        tfds_name='coco/2017',
        tfds_split='validation',
        is_training=is_training,
        global_batch_size=batch_size,
        output_size=(output_size, output_size),
        max_num_boxes=max_num_boxes,
    )

    dl = coco.COCODataLoader(data_config)
    inputs = _gen_fn()
    image, label = dl.preprocess(inputs)
    self.assertEqual(image.shape, (output_size, output_size, 3))
    # Rank-1 shapes are written as real 1-tuples; `(n)` is only an int.
    self.assertEqual(label['classes'].shape, (max_num_boxes,))
    self.assertEqual(label['boxes'].shape, (max_num_boxes, 4))
    if not is_training:
      self.assertDTypeEqual(label['id'], int)
      self.assertEqual(
          label['image_info'].shape, (4, 2))
      self.assertEqual(label['is_crowd'].shape, (max_num_boxes,))
# Run the test cases only when this file is executed as a script.
if __name__ == '__main__':
  tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""COCO data loader for DETR."""
from typing import Tuple
import tensorflow as tf
from official.vision.dataloaders import parser
from official.vision.ops import box_ops
from official.vision.ops import preprocess_ops
# Default candidate resize targets for `Parser`: one entry is drawn at random
# for training and the last entry is used for evaluation.
RESIZE_SCALES = (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
class Parser(parser.Parser):
  """Parse an image and its annotations into a dictionary of tensors."""

  def __init__(self,
               class_offset: int = 0,
               output_size: Tuple[int, int] = (1333, 1333),
               max_num_boxes: int = 100,
               resize_scales: Tuple[int, ...] = RESIZE_SCALES,
               aug_rand_hflip=True):
    """Initializes the parser.

    Args:
      class_offset: value added to the groundtruth classes of training data.
      output_size: (height, width) that images are padded to; its maximum is
        also passed to the resize step as the size limit.
      max_num_boxes: fixed number of entries that box labels are clipped or
        padded to.
      resize_scales: candidate resize targets; one is drawn at random for
        training and the last entry is used for evaluation.
      aug_rand_hflip: whether training images are randomly flipped
        horizontally.
    """
    self._class_offset = class_offset
    self._output_size = output_size
    self._max_num_boxes = max_num_boxes
    self._resize_scales = resize_scales
    self._aug_rand_hflip = aug_rand_hflip

  def _parse_train_data(self, data):
    """Parses data for training.

    Args:
      data: a dict with the keys `image`, `groundtruth_classes`,
        `groundtruth_boxes` and `groundtruth_is_crowd`.

    Returns:
      A tuple `(image, labels)` where `image` is normalized, augmented,
      resized and padded to `output_size`, and `labels` holds `classes` and
      `boxes`, both clipped or padded to `max_num_boxes`.
    """
    classes = data['groundtruth_classes'] + self._class_offset
    boxes = data['groundtruth_boxes']
    is_crowd = data['groundtruth_is_crowd']

    # Gets original image.
    image = data['image']

    # Normalizes image with mean and std pixel values.
    image = preprocess_ops.normalize_image(image)
    # Honor the constructor flag; the flip used to be applied unconditionally.
    if self._aug_rand_hflip:
      image, boxes, _ = preprocess_ops.random_horizontal_flip(image, boxes)

    # With probability 0.5, rescale and take a random crop before the final
    # multi-scale resize.
    do_crop = tf.greater(tf.random.uniform([]), 0.5)
    if do_crop:
      # Rescale to a randomly chosen size from {400, 500, 600}.
      boxes = box_ops.denormalize_boxes(boxes, tf.shape(image)[:2])
      index = tf.random.categorical(tf.zeros([1, 3]), 1)[0]
      scales = tf.gather([400.0, 500.0, 600.0], index, axis=0)
      short_side = scales[0]
      image, image_info = preprocess_ops.resize_image(image, short_side)
      boxes = preprocess_ops.resize_and_crop_boxes(boxes, image_info[2, :],
                                                   image_info[1, :],
                                                   image_info[3, :])
      boxes = box_ops.normalize_boxes(boxes, image_info[1, :])

      # Random crop: each side is drawn from [384, min(side, 600)) and the
      # offset is drawn so that the window stays inside the image.
      shape = tf.cast(image_info[1], dtype=tf.int32)
      h = tf.random.uniform([],
                            384,
                            tf.math.minimum(shape[0], 600),
                            dtype=tf.int32)
      w = tf.random.uniform([],
                            384,
                            tf.math.minimum(shape[1], 600),
                            dtype=tf.int32)
      i = tf.random.uniform([], 0, shape[0] - h + 1, dtype=tf.int32)
      j = tf.random.uniform([], 0, shape[1] - w + 1, dtype=tf.int32)
      image = tf.image.crop_to_bounding_box(image, i, j, h, w)
      # Re-express the boxes relative to the crop window, clipped to [0, 1].
      boxes = tf.clip_by_value(
          (boxes[..., :] * tf.cast(
              tf.stack([shape[0], shape[1], shape[0], shape[1]]),
              dtype=tf.float32) -
           tf.cast(tf.stack([i, j, i, j]), dtype=tf.float32)) /
          tf.cast(tf.stack([h, w, h, w]), dtype=tf.float32), 0.0, 1.0)
    scales = tf.constant(self._resize_scales, dtype=tf.float32)
    # Draw uniformly over however many scales are configured instead of
    # assuming the default tuple of 11 entries.
    num_scales = len(self._resize_scales)
    index = tf.random.categorical(tf.zeros([1, num_scales]), 1)[0]
    scales = tf.gather(scales, index, axis=0)

    image_shape = tf.shape(image)[:2]
    boxes = box_ops.denormalize_boxes(boxes, image_shape)
    short_side = scales[0]
    image, image_info = preprocess_ops.resize_image(image, short_side,
                                                    max(self._output_size))
    boxes = preprocess_ops.resize_and_crop_boxes(boxes, image_info[2, :],
                                                 image_info[1, :],
                                                 image_info[3, :])
    boxes = box_ops.normalize_boxes(boxes, image_info[1, :])

    # Filters out ground truth boxes that are all zeros.
    indices = box_ops.get_non_empty_box_indices(boxes)
    boxes = tf.gather(boxes, indices)
    classes = tf.gather(classes, indices)
    is_crowd = tf.gather(is_crowd, indices)
    boxes = box_ops.yxyx_to_cycxhw(boxes)

    image = tf.image.pad_to_bounding_box(image, 0, 0, self._output_size[0],
                                         self._output_size[1])
    labels = {
        'classes':
            preprocess_ops.clip_or_pad_to_fixed_size(classes,
                                                     self._max_num_boxes),
        'boxes':
            preprocess_ops.clip_or_pad_to_fixed_size(boxes, self._max_num_boxes)
    }

    return image, labels

  def _parse_eval_data(self, data):
    """Parses data for evaluation.

    Args:
      data: a dict with the keys `image`, `source_id`, `groundtruth_classes`,
        `groundtruth_boxes` and `groundtruth_is_crowd`.

    Returns:
      A tuple `(image, labels)` where `image` is normalized, resized with the
      last configured scale and padded to `output_size`, and `labels` holds
      `classes`, `boxes`, `id`, `image_info`, `is_crowd` and `gt_boxes` (the
      boxes denormalized by the image shape before resizing).
    """
    # NOTE(review): unlike training, `class_offset` is not added here --
    # confirm this is intended by the consumers of the eval labels.
    classes = data['groundtruth_classes']
    boxes = data['groundtruth_boxes']
    is_crowd = data['groundtruth_is_crowd']

    # Gets original image and its size.
    image = data['image']

    # Normalizes image with mean and std pixel values.
    image = preprocess_ops.normalize_image(image)

    # Evaluation always uses the last configured scale.
    scales = tf.constant([self._resize_scales[-1]], tf.float32)

    image_shape = tf.shape(image)[:2]
    boxes = box_ops.denormalize_boxes(boxes, image_shape)
    gt_boxes = boxes
    short_side = scales[0]
    image, image_info = preprocess_ops.resize_image(image, short_side,
                                                    max(self._output_size))
    boxes = preprocess_ops.resize_and_crop_boxes(boxes, image_info[2, :],
                                                 image_info[1, :],
                                                 image_info[3, :])
    boxes = box_ops.normalize_boxes(boxes, image_info[1, :])

    # Filters out ground truth boxes that are all zeros.
    indices = box_ops.get_non_empty_box_indices(boxes)
    boxes = tf.gather(boxes, indices)
    classes = tf.gather(classes, indices)
    is_crowd = tf.gather(is_crowd, indices)
    boxes = box_ops.yxyx_to_cycxhw(boxes)

    image = tf.image.pad_to_bounding_box(image, 0, 0, self._output_size[0],
                                         self._output_size[1])
    labels = {
        'classes':
            preprocess_ops.clip_or_pad_to_fixed_size(classes,
                                                     self._max_num_boxes),
        'boxes':
            preprocess_ops.clip_or_pad_to_fixed_size(boxes, self._max_num_boxes)
    }
    labels.update({
        'id':
            # NOTE(review): presumably relies on the tf.data/AutoGraph
            # conversion of `int()` when `source_id` is a tensor -- confirm.
            int(data['source_id']),
        'image_info':
            image_info,
        'is_crowd':
            preprocess_ops.clip_or_pad_to_fixed_size(is_crowd,
                                                     self._max_num_boxes),
        'gt_boxes':
            preprocess_ops.clip_or_pad_to_fixed_size(gt_boxes,
                                                     self._max_num_boxes),
    })

    return image, labels
#!/bin/bash
# Trains and evaluates the `detr_coco` experiment, writing to /tmp/logging_dir/.
# Overrides: `task.init_checkpoint` points at the `resnet50_imagenet`
# checkpoint on GCS, training stops after 554400 steps, and the stepwise
# learning-rate boundary is moved to step 369600. At 118287 // 64 = 1848 steps
# per epoch these are 300 and 200 epochs respectively.
python3 official/projects/detr/train.py \
--experiment=detr_coco \
--mode=train_and_eval \
--model_dir=/tmp/logging_dir/ \
--params_override=task.init_checkpoint='gs://tf_model_garden/vision/resnet50_imagenet/ckpt-62400',trainer.train_steps=554400,trainer.optimizer_config.learning_rate.stepwise.boundaries="[369600]"
#!/bin/bash
# Trains and evaluates the `detr_coco` experiment, writing to /tmp/logging_dir/.
# Only `task.init_checkpoint` is overridden (the `resnet50_imagenet`
# checkpoint on GCS); the training schedule comes from the experiment config.
python3 official/projects/detr/train.py \
--experiment=detr_coco \
--mode=train_and_eval \
--model_dir=/tmp/logging_dir/ \
--params_override=task.init_checkpoint='gs://tf_model_garden/vision/resnet50_imagenet/ckpt-62400'
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements End-to-End Object Detection with Transformers.
Model paper: https://arxiv.org/abs/2005.12872
This module does not support Keras de/serialization. Please use
tf.train.Checkpoint for object based saving and loading and tf.saved_model.save
for graph serialization.
"""
import math
from typing import Any, List
import tensorflow as tf
from official.modeling import tf_utils
from official.projects.detr.modeling import transformer
def position_embedding_sine(attention_mask,
                            num_pos_features=256,
                            temperature=10000.,
                            normalize=True,
                            scale=2 * math.pi):
  """Sine-based positional embeddings for 2D images.

  Row and column positions are obtained as cumulative sums of the mask along
  the height and width axes, optionally normalized, and then encoded with sine
  and cosine features. Half of the features encode the row position and the
  other half the column position.

  Args:
    attention_mask: a `bool` Tensor specifying the size of the input image to
      the Transformer and which elements are padded, of size [batch_size,
      height, width]
    num_pos_features: a `int` specifying the number of positional features,
      should be equal to the hidden size of the Transformer network
    temperature: a `float` specifying the temperature of the positional
      embedding. Any type that is converted to a `float` can also be accepted.
    normalize: a `bool` determining whether the positional embeddings should be
      normalized between [0, scale] before application of the sine and cos
      functions.
    scale: a `float` if normalize is True specifying the scale embeddings before
      application of the embedding function.

  Returns:
    embeddings: a `float` tensor of shape [batch_size, height, width,
      num_pos_features] specifying the positional embeddings based on sine
      features.

  Raises:
    ValueError: if `num_pos_features` is odd.
  """
  if num_pos_features % 2:
    raise ValueError(
        "Number of embedding features (num_pos_features) must be even when "
        "column and row embeddings are concatenated.")
  features_per_axis = num_pos_features // 2

  # Running count of mask entries along each spatial axis.
  # <tf.float>[batch_size, height, width]
  mask_as_float = tf.cast(attention_mask, tf.float32)
  row_position = tf.cumsum(mask_as_float, 1)
  col_position = tf.cumsum(mask_as_float, 2)

  if normalize:
    eps = 1e-6
    # Divide by the last (largest) cumulative value of each column / row.
    row_position = row_position / (row_position[:, -1:, :] + eps) * scale
    col_position = col_position / (col_position[:, :, -1:] + eps) * scale

  feature_index = tf.range(features_per_axis, dtype=row_position.dtype)
  divisor = tf.pow(temperature, 2 * (feature_index // 2) / features_per_axis)

  def _sin_cos_pairs(position):
    # <tf.float>[batch_size, height, width, features_per_axis // 2, 2]
    angles = tf.expand_dims(position, -1) / divisor
    return tf.stack(
        [tf.sin(angles[:, :, :, 0::2]),
         tf.cos(angles[:, :, :, 1::2])], axis=4)

  row_pairs = _sin_cos_pairs(row_position)
  col_pairs = _sin_cos_pairs(col_position)

  # Flatten the (pair, sin/cos) axes back into a single feature axis.
  flat_shape = tf_utils.get_shape_list(row_pairs)[:3] + [-1]
  row_features = tf.reshape(row_pairs, flat_shape)
  col_features = tf.reshape(col_pairs, flat_shape)

  return tf.cast(tf.concat([row_features, col_features], -1), tf.float32)
class DETR(tf.keras.Model):
  """DETR model with Keras.

  DETR consists of backbone, query embedding, DETRTransformer,
  class and box heads.
  """

  def __init__(self,
               backbone,
               backbone_endpoint_name,
               num_queries,
               hidden_size,
               num_classes,
               num_encoder_layers=6,
               num_decoder_layers=6,
               dropout_rate=0.1,
               **kwargs):
    """Initializes the model.

    Args:
      backbone: the backbone model; calling it must return a mapping from
        endpoint name to feature tensor.
      backbone_endpoint_name: key of the backbone output fed to the
        transformer.
      num_queries: number of query embeddings.
      hidden_size: feature size of the transformer inputs and queries; must
        be even.
      num_classes: output size of the classification head.
      num_encoder_layers: number of transformer encoder layers.
      num_decoder_layers: number of transformer decoder layers.
      dropout_rate: dropout rate passed to the transformer.
      **kwargs: keyword arguments passed to `tf.keras.Model`.

    Raises:
      ValueError: if `hidden_size` is odd.
    """
    super().__init__(**kwargs)
    self._num_queries = num_queries
    self._hidden_size = hidden_size
    self._num_classes = num_classes
    self._num_encoder_layers = num_encoder_layers
    self._num_decoder_layers = num_decoder_layers
    self._dropout_rate = dropout_rate
    if hidden_size % 2:
      raise ValueError("hidden_size must be a multiple of 2.")
    self._backbone = backbone
    self._backbone_endpoint_name = backbone_endpoint_name

  def build(self, input_shape=None):
    """Creates the input projection, the transformer and the heads."""
    self._input_proj = tf.keras.layers.Conv2D(
        filters=self._hidden_size, kernel_size=1, name="detr/conv2d")
    self._build_detection_decoder()
    super().build(input_shape)

  def _build_detection_decoder(self):
    """Builds detection decoder."""
    self._transformer = DETRTransformer(
        num_encoder_layers=self._num_encoder_layers,
        num_decoder_layers=self._num_decoder_layers,
        dropout_rate=self._dropout_rate)
    self._query_embeddings = self.add_weight(
        "detr/query_embeddings",
        shape=[self._num_queries, self._hidden_size],
        initializer=tf.keras.initializers.RandomNormal(mean=0., stddev=1.),
        dtype=tf.float32)

    # Every head kernel is drawn uniformly from [-sqrt(1/h), sqrt(1/h)].
    limit = math.sqrt(1.0 / self._hidden_size)

    def _uniform_kernel_init():
      return tf.keras.initializers.RandomUniform(-limit, limit)

    self._class_embed = tf.keras.layers.Dense(
        self._num_classes,
        kernel_initializer=_uniform_kernel_init(),
        name="detr/cls_dense")

    # The box head is two ReLU hidden layers followed by a 4-unit projection.
    box_layers = []
    for layer_name in ("detr/box_dense_0", "detr/box_dense_1"):
      box_layers.append(
          tf.keras.layers.Dense(
              self._hidden_size,
              activation="relu",
              kernel_initializer=_uniform_kernel_init(),
              name=layer_name))
    box_layers.append(
        tf.keras.layers.Dense(
            4,
            kernel_initializer=_uniform_kernel_init(),
            name="detr/box_dense_2"))
    self._bbox_embed = box_layers
    self._sigmoid = tf.keras.layers.Activation("sigmoid")

  @property
  def backbone(self) -> tf.keras.Model:
    """The backbone model passed to the constructor."""
    return self._backbone

  def get_config(self):
    """Returns the constructor arguments as a dict."""
    return dict(
        backbone=self._backbone,
        backbone_endpoint_name=self._backbone_endpoint_name,
        num_queries=self._num_queries,
        hidden_size=self._hidden_size,
        num_classes=self._num_classes,
        num_encoder_layers=self._num_encoder_layers,
        num_decoder_layers=self._num_decoder_layers,
        dropout_rate=self._dropout_rate,
    )

  @classmethod
  def from_config(cls, config):
    """Builds a model from the dict returned by `get_config`."""
    return cls(**config)

  def _generate_image_mask(self, inputs: tf.Tensor,
                           target_shape: tf.Tensor) -> tf.Tensor:
    """Generates image mask from input image.

    A position is marked 1 where the sum over the last axis of `inputs` is
    non-zero and 0 elsewhere; the mask is then resized to `target_shape` with
    nearest-neighbor interpolation.
    """
    channel_sum = tf.reduce_sum(inputs, axis=-1)
    non_zero = tf.cast(tf.not_equal(channel_sum, 0), inputs.dtype)
    return tf.image.resize(
        tf.expand_dims(non_zero, axis=-1),
        target_shape,
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  def call(self, inputs: tf.Tensor) -> List[Any]:
    """Runs the model on a batch of images.

    Args:
      inputs: the image batch; presumably [batch_size, height, width,
        channels] -- confirm against the backbone.

    Returns:
      A list with one dict per decoder output returned by the transformer;
      each dict holds `cls_outputs` (class head output) and `box_outputs`
      (sigmoid of the box head output).
    """
    batch_size = tf.shape(inputs)[0]
    features = self._backbone(inputs)[self._backbone_endpoint_name]
    feature_hw = tf.shape(features)[1: 3]
    mask = self._generate_image_mask(inputs, feature_hw)

    # Flatten the spatial axes of embeddings, features and mask.
    pos_embed = tf.reshape(
        position_embedding_sine(
            mask[:, :, :, 0], num_pos_features=self._hidden_size),
        [batch_size, -1, self._hidden_size])
    flat_features = tf.reshape(
        self._input_proj(features), [batch_size, -1, self._hidden_size])
    flat_mask = tf.reshape(mask, [batch_size, -1])
    queries = tf.tile(
        tf.expand_dims(self._query_embeddings, axis=0), (batch_size, 1, 1))

    decoded_list = self._transformer({
        "inputs": flat_features,
        "targets": queries,
        "pos_embed": pos_embed,
        "mask": flat_mask,
    })

    out_list = []
    for decoded in decoded_list:
      decoded = tf.stack(decoded)
      class_logits = self._class_embed(decoded)
      box_hidden = decoded
      for box_layer in self._bbox_embed:
        box_hidden = box_layer(box_hidden)
      out_list.append({
          "cls_outputs": class_logits,
          "box_outputs": self._sigmoid(box_hidden),
      })
    return out_list
class DETRTransformer(tf.keras.layers.Layer):
  """Encoder and Decoder of DETR."""

  def __init__(self, num_encoder_layers=6, num_decoder_layers=6,
               dropout_rate=0.1, **kwargs):
    """Initializes the layer.

    Args:
      num_encoder_layers: number of encoder layers; with 0 no encoder is
        built and the inputs are used as memory directly.
      num_decoder_layers: number of decoder layers.
      dropout_rate: rate used for every dropout of the encoder and decoder.
      **kwargs: keyword arguments passed to `tf.keras.layers.Layer`.
    """
    super().__init__(**kwargs)
    self._dropout_rate = dropout_rate
    self._num_encoder_layers = num_encoder_layers
    self._num_decoder_layers = num_decoder_layers

  def build(self, input_shape=None):
    """Creates the encoder (if requested) and the decoder."""
    self._encoder = None
    if self._num_encoder_layers > 0:
      self._encoder = transformer.TransformerEncoder(
          attention_dropout_rate=self._dropout_rate,
          dropout_rate=self._dropout_rate,
          intermediate_dropout=self._dropout_rate,
          norm_first=False,
          num_layers=self._num_encoder_layers)

    self._decoder = transformer.TransformerDecoder(
        attention_dropout_rate=self._dropout_rate,
        dropout_rate=self._dropout_rate,
        intermediate_dropout=self._dropout_rate,
        norm_first=False,
        num_layers=self._num_decoder_layers)
    super().build(input_shape)

  def get_config(self):
    """Returns the constructor arguments as a dict."""
    return dict(
        num_encoder_layers=self._num_encoder_layers,
        num_decoder_layers=self._num_decoder_layers,
        dropout_rate=self._dropout_rate,
    )

  def call(self, inputs):
    """Encodes the sources and decodes the targets against them.

    Args:
      inputs: a dict with the keys `inputs` (source features), `targets`
        (query embeddings), `pos_embed` (source position embeddings) and
        `mask` (source mask).

    Returns:
      The decoder output, requested with `return_all_decoder_outputs=True`.
    """
    sources = inputs["inputs"]
    targets = inputs["targets"]
    pos_embed = inputs["pos_embed"]
    mask = inputs["mask"]

    # Repeat the source mask once per source position.
    num_sources = tf_utils.get_shape_list(sources)[1]
    source_attention_mask = tf.tile(
        tf.expand_dims(mask, axis=1), [1, num_sources, 1])
    if self._encoder is None:
      memory = sources
    else:
      memory = self._encoder(
          sources, attention_mask=source_attention_mask, pos_embed=pos_embed)

    # Repeat the source mask once per target position.
    num_targets = tf_utils.get_shape_list(targets)[1]
    cross_attention_mask = tf.tile(
        tf.expand_dims(mask, axis=1), [1, num_targets, 1])

    dynamic_target_shape = tf.shape(targets)
    # TODO(b/199545430): self_attention_mask could be set to None when this
    # bug is resolved. Passing ones for now.
    self_attention_mask = tf.ones(
        (dynamic_target_shape[0], dynamic_target_shape[1],
         dynamic_target_shape[1]))

    # The decoder starts from zeros; the targets are passed as position
    # embeddings instead.
    return self._decoder(
        tf.zeros_like(targets),
        memory,
        self_attention_mask=self_attention_mask,
        cross_attention_mask=cross_attention_mask,
        return_all_decoder_outputs=True,
        input_pos_embed=targets,
        memory_pos_embed=pos_embed)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tensorflow_models.official.projects.detr.detr."""
import tensorflow as tf
from official.projects.detr.modeling import detr
from official.vision.modeling.backbones import resnet
class DetrTest(tf.test.TestCase):
  """Forward-pass and config round-trip tests for the DETR model."""

  def test_forward(self):
    """The model returns one correctly shaped output per decoder layer."""
    batch_size = 2
    image_size = 640
    num_queries = 10
    hidden_size = 128
    num_classes = 10

    model = detr.DETR(
        backbone=resnet.ResNet(50, bn_trainable=False),
        backbone_endpoint_name='5',
        num_queries=num_queries,
        hidden_size=hidden_size,
        num_classes=num_classes)
    images = tf.ones((batch_size, image_size, image_size, 3))
    outs = model(images)

    self.assertLen(outs, 6)  # intermediate decoded outputs.
    expected_cls_shape = (batch_size, num_queries, num_classes)
    expected_box_shape = (batch_size, num_queries, 4)
    for out in outs:
      self.assertAllEqual(tf.shape(out['cls_outputs']), expected_cls_shape)
      self.assertAllEqual(tf.shape(out['box_outputs']), expected_box_shape)

  def test_get_from_config_detr_transformer(self):
    """`DETRTransformer` config survives a from_config/get_config round trip."""
    expected_config = dict(
        num_encoder_layers=1,
        num_decoder_layers=2,
        dropout_rate=0.5,
    )
    layer = detr.DETRTransformer.from_config(expected_config)
    self.assertEqual(expected_config, layer.get_config())

  def test_get_from_config_detr(self):
    """`DETR` config survives a from_config/get_config round trip."""
    expected_config = dict(
        backbone=resnet.ResNet(50, bn_trainable=False),
        backbone_endpoint_name='5',
        num_queries=2,
        hidden_size=4,
        num_classes=10,
        num_encoder_layers=4,
        num_decoder_layers=5,
        dropout_rate=0.5,
    )
    model = detr.DETR.from_config(expected_config)
    self.assertEqual(expected_config, model.get_config())
# Run the test cases only when this file is executed as a script.
if __name__ == '__main__':
  tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Specialized Transformers for DETR.
The position embeddings are added to the query and key for every self- and
cross-attention layer.
"""
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling import layers
from official.nlp.modeling import models
class TransformerEncoder(tf.keras.layers.Layer):
  """Stack of Transformer encoder blocks followed by a layer normalization.

  The encoder runs `num_layers` `TransformerEncoderBlock` layers in sequence.
  Each block is composed of the sublayers:
    1. Self-attention layer
    2. Feedforward network (which is 2 fully-connected layers)
  """

  def __init__(self,
               num_layers=6,
               num_attention_heads=8,
               intermediate_size=2048,
               activation="relu",
               dropout_rate=0.0,
               attention_dropout_rate=0.0,
               use_bias=False,
               norm_first=True,
               norm_epsilon=1e-6,
               intermediate_dropout=0.0,
               **kwargs):
    """Stores the hyper-parameters; sub-layers are created in `build()`.

    Args:
      num_layers: Number of encoder blocks in the stack.
      num_attention_heads: Number of attention heads per block.
      intermediate_size: Size of the intermediate (feedforward) layer.
      activation: Activation for the intermediate layer.
      dropout_rate: Dropout probability for the block outputs.
      attention_dropout_rate: Dropout probability inside the attention layers.
      use_bias: Whether the attention layers use a bias term.
      norm_first: If True, inputs to the attention and intermediate dense
        layers are normalized; if False, their outputs are normalized.
      norm_epsilon: Epsilon value used by the normalization layers.
      intermediate_dropout: Dropout probability applied after the intermediate
        activation.
      **kwargs: Keyword arguments passed to `tf.keras.layers.Layer`.
    """
    super().__init__(**kwargs)
    # Shape of the stack.
    self.num_layers = num_layers
    self.num_attention_heads = num_attention_heads
    self._intermediate_size = intermediate_size
    self._activation = activation
    # Regularization.
    self._dropout_rate = dropout_rate
    self._attention_dropout_rate = attention_dropout_rate
    self._intermediate_dropout = intermediate_dropout
    # Attention / normalization behaviour.
    self._use_bias = use_bias
    self._norm_first = norm_first
    self._norm_epsilon = norm_epsilon

  def build(self, input_shape):
    """Creates the encoder blocks and the final layer normalization."""

    def _make_block(index):
      # Every block gets its own clone of the attention initializer, sized by
      # the third entry of the input shape.
      initializer = tf_utils.clone_initializer(
          models.seq2seq_transformer.attention_initializer(input_shape[2]))
      return TransformerEncoderBlock(
          num_attention_heads=self.num_attention_heads,
          inner_dim=self._intermediate_size,
          inner_activation=self._activation,
          output_dropout=self._dropout_rate,
          attention_dropout=self._attention_dropout_rate,
          use_bias=self._use_bias,
          norm_first=self._norm_first,
          norm_epsilon=self._norm_epsilon,
          inner_dropout=self._intermediate_dropout,
          attention_initializer=initializer,
          name=("layer_%d" % index))

    self.encoder_layers = [_make_block(i) for i in range(self.num_layers)]
    self.output_normalization = tf.keras.layers.LayerNormalization(
        epsilon=self._norm_epsilon, dtype="float32")
    super().build(input_shape)

  def get_config(self):
    """Returns the base layer config extended with the encoder's settings."""
    encoder_config = {
        "num_layers": self.num_layers,
        "num_attention_heads": self.num_attention_heads,
        "intermediate_size": self._intermediate_size,
        "activation": self._activation,
        "dropout_rate": self._dropout_rate,
        "attention_dropout_rate": self._attention_dropout_rate,
        "use_bias": self._use_bias,
        "norm_first": self._norm_first,
        "norm_epsilon": self._norm_epsilon,
        "intermediate_dropout": self._intermediate_dropout
    }
    # Encoder entries take precedence over same-named base entries.
    return {**super().get_config(), **encoder_config}

  def call(self, encoder_inputs, attention_mask=None, pos_embed=None):
    """Runs the inputs through every encoder block, then normalizes.

    Args:
      encoder_inputs: A tensor with shape `(batch_size, input_length,
        hidden_size)`.
      attention_mask: A mask for the encoder self-attention layer with shape
        `(batch_size, input_length, input_length)`.
      pos_embed: Position embedding forwarded to every encoder block.

    Returns:
      The output of the final layer normalization, with shape
      `(batch_size, input_length, hidden_size)`.
    """
    hidden_states = encoder_inputs
    for block in self.encoder_layers:
      hidden_states = block([hidden_states, attention_mask, pos_embed])
    return self.output_normalization(hidden_states)
class TransformerEncoderBlock(tf.keras.layers.Layer):
  """TransformerEncoderBlock layer.

  This layer implements the Transformer Encoder from
  "Attention Is All You Need". (https://arxiv.org/abs/1706.03762),
  which combines a `tf.keras.layers.MultiHeadAttention` layer with a
  two-layer feedforward network. The only difference: position embedding is
  added to the query and key of self-attention.

  References:
    [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
    [BERT: Pre-training of Deep Bidirectional Transformers for Language
     Understanding](https://arxiv.org/abs/1810.04805)
  """

  def __init__(self,
               num_attention_heads,
               inner_dim,
               inner_activation,
               output_range=None,
               kernel_initializer="glorot_uniform",
               bias_initializer="zeros",
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               use_bias=True,
               norm_first=False,
               norm_epsilon=1e-12,
               output_dropout=0.0,
               attention_dropout=0.0,
               inner_dropout=0.0,
               attention_initializer=None,
               attention_axes=None,
               **kwargs):
    """Initializes `TransformerEncoderBlock`.

    Args:
      num_attention_heads: Number of attention heads.
      inner_dim: The output dimension of the first Dense layer in a two-layer
        feedforward network.
      inner_activation: The activation for the first Dense layer in a two-layer
        feedforward network.
      output_range: the sequence output range, [0, output_range) for slicing the
        target sequence. `None` means the target sequence is not sliced.
      kernel_initializer: Initializer for dense layer kernels.
      bias_initializer: Initializer for dense layer biases.
      kernel_regularizer: Regularizer for dense layer kernels.
      bias_regularizer: Regularizer for dense layer biases.
      activity_regularizer: Regularizer for dense layer activity.
      kernel_constraint: Constraint for dense layer kernels.
      bias_constraint: Constraint for dense layer biases.
      use_bias: Whether to enable use_bias in attention layer. If set False,
        use_bias in attention layer is disabled.
      norm_first: Whether to normalize inputs to attention and intermediate
        dense layers. If set False, output of attention and intermediate dense
        layers is normalized.
      norm_epsilon: Epsilon value to initialize normalization layers.
      output_dropout: Dropout probability for the post-attention and output
        dropout.
      attention_dropout: Dropout probability for within the attention layer.
      inner_dropout: Dropout probability for the first Dense layer in a
        two-layer feedforward network.
      attention_initializer: Initializer for kernels of attention layers. If set
        `None`, attention layers use kernel_initializer as initializer for
        kernel.
      attention_axes: axes over which the attention is applied. `None` means
        attention over all axes, but batch, heads, and features.
      **kwargs: Keyword arguments passed to `tf.keras.layers.Layer`.
    """
    super().__init__(**kwargs)

    self._num_heads = num_attention_heads
    self._inner_dim = inner_dim
    self._inner_activation = inner_activation
    # The `*_rate` attributes hold the configured probabilities for the whole
    # lifetime of the layer. The un-suffixed names are kept for backward
    # compatibility only: `build()` rebinds them to `Dropout` layers.
    self._attention_dropout = attention_dropout
    self._attention_dropout_rate = attention_dropout
    self._output_dropout = output_dropout
    self._output_dropout_rate = output_dropout
    self._output_range = output_range
    self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
    self._bias_initializer = tf.keras.initializers.get(bias_initializer)
    self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
    self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
    self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
    self._bias_constraint = tf.keras.constraints.get(bias_constraint)
    self._use_bias = use_bias
    self._norm_first = norm_first
    self._norm_epsilon = norm_epsilon
    self._inner_dropout = inner_dropout
    if attention_initializer:
      self._attention_initializer = tf.keras.initializers.get(
          attention_initializer)
    else:
      self._attention_initializer = tf_utils.clone_initializer(
          self._kernel_initializer)
    self._attention_axes = attention_axes

  def build(self, input_shape):
    """Creates the sub-layers once the input width is known.

    Args:
      input_shape: Either the `tf.TensorShape` of the input tensor, or a
        list/tuple of shapes whose first entry is the input tensor's shape.

    Raises:
      ValueError: If `input_shape` has an unsupported type, or if the last
        input dimension is not a multiple of the number of attention heads.
    """
    if isinstance(input_shape, tf.TensorShape):
      input_tensor_shape = input_shape
    elif isinstance(input_shape, (list, tuple)):
      input_tensor_shape = tf.TensorShape(input_shape[0])
    else:
      raise ValueError(
          "The type of input shape argument is not supported, got: %s" %
          type(input_shape))
    einsum_equation = "abc,cd->abd"
    if len(input_tensor_shape.as_list()) > 3:
      einsum_equation = "...bc,cd->...bd"
    hidden_size = input_tensor_shape[-1]
    if hidden_size % self._num_heads != 0:
      raise ValueError(
          "The input size (%d) is not a multiple of the number of attention "
          "heads (%d)" % (hidden_size, self._num_heads))
    self._attention_head_size = int(hidden_size // self._num_heads)
    common_kwargs = dict(
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint)
    # Rates are read from the `*_rate` attributes because `_attention_dropout`
    # and `_output_dropout` are rebound to layers below; reading the rebound
    # names would break if `build()` ever ran a second time.
    self._attention_layer = tf.keras.layers.MultiHeadAttention(
        num_heads=self._num_heads,
        key_dim=self._attention_head_size,
        dropout=self._attention_dropout_rate,
        use_bias=self._use_bias,
        kernel_initializer=self._attention_initializer,
        attention_axes=self._attention_axes,
        name="self_attention",
        **common_kwargs)
    # Post-attention dropout uses the *output* dropout rate.
    self._attention_dropout = tf.keras.layers.Dropout(
        rate=self._output_dropout_rate)
    # Use float32 in layernorm for numeric stability.
    # It is probably safe in mixed_float16, but we haven't validated this yet.
    self._attention_layer_norm = (
        tf.keras.layers.LayerNormalization(
            name="self_attention_layer_norm",
            axis=-1,
            epsilon=self._norm_epsilon,
            dtype=tf.float32))
    self._intermediate_dense = tf.keras.layers.EinsumDense(
        einsum_equation,
        output_shape=(None, self._inner_dim),
        bias_axes="d",
        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
        name="intermediate",
        **common_kwargs)
    policy = tf.keras.mixed_precision.global_policy()
    if policy.name == "mixed_bfloat16":
      # bfloat16 causes BERT with the LAMB optimizer to not converge
      # as well, so we use float32.
      # TODO(b/154538392): Investigate this.
      policy = tf.float32
    self._intermediate_activation_layer = tf.keras.layers.Activation(
        self._inner_activation, dtype=policy)
    self._inner_dropout_layer = tf.keras.layers.Dropout(
        rate=self._inner_dropout)
    self._output_dense = tf.keras.layers.EinsumDense(
        einsum_equation,
        output_shape=(None, hidden_size),
        bias_axes="d",
        name="output",
        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
        **common_kwargs)
    self._output_dropout = tf.keras.layers.Dropout(
        rate=self._output_dropout_rate)
    # Use float32 in layernorm for numeric stability.
    self._output_layer_norm = tf.keras.layers.LayerNormalization(
        name="output_layer_norm",
        axis=-1,
        epsilon=self._norm_epsilon,
        dtype=tf.float32)

    super().build(input_shape)

  def get_config(self):
    """Returns the base layer config extended with this block's settings."""
    config = {
        "num_attention_heads":
            self._num_heads,
        "inner_dim":
            self._inner_dim,
        "inner_activation":
            self._inner_activation,
        "output_dropout":
            self._output_dropout_rate,
        "attention_dropout":
            self._attention_dropout_rate,
        "output_range":
            self._output_range,
        "kernel_initializer":
            tf.keras.initializers.serialize(self._kernel_initializer),
        "bias_initializer":
            tf.keras.initializers.serialize(self._bias_initializer),
        "kernel_regularizer":
            tf.keras.regularizers.serialize(self._kernel_regularizer),
        "bias_regularizer":
            tf.keras.regularizers.serialize(self._bias_regularizer),
        "activity_regularizer":
            tf.keras.regularizers.serialize(self._activity_regularizer),
        "kernel_constraint":
            tf.keras.constraints.serialize(self._kernel_constraint),
        "bias_constraint":
            tf.keras.constraints.serialize(self._bias_constraint),
        "use_bias":
            self._use_bias,
        "norm_first":
            self._norm_first,
        "norm_epsilon":
            self._norm_epsilon,
        "inner_dropout":
            self._inner_dropout,
        "attention_initializer":
            tf.keras.initializers.serialize(self._attention_initializer),
        "attention_axes":
            self._attention_axes,
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @staticmethod
  def _unpack_inputs(inputs):
    """Normalizes `inputs` to `(input_tensor, attention_mask, pos_embed)`."""
    if not isinstance(inputs, (list, tuple)):
      return inputs, None, None
    if not 1 <= len(inputs) <= 3:
      raise ValueError(
          "TransformerEncoderBlock expects between 1 and 3 inputs, got: %d" %
          len(inputs))
    # Missing trailing entries mean "no mask" / "no position embedding".
    padded = list(inputs) + [None] * (3 - len(inputs))
    return padded[0], padded[1], padded[2]

  def call(self, inputs):
    """Transformer self-attention encoder block call.

    Args:
      inputs: a single tensor or a list of tensors. `input tensor` as the single
        sequence of embeddings. [`input tensor`, `attention mask`] to have the
        additional attention mask. [`input tensor`, `attention mask`, `pos
        embed`] to have an additional position embedding added to the query
        and key of self-attention. The mask and the position embedding may
        each be `None`.

    Returns:
      An output tensor with the same dimensions as input/query tensor.

    Raises:
      ValueError: If `inputs` is a list/tuple that does not hold 1 to 3 items.
    """
    input_tensor, attention_mask, pos_embed = self._unpack_inputs(inputs)

    if self._output_range:
      if self._norm_first:
        # Keep the un-normalized slice for the residual connection.
        source_tensor = input_tensor[:, 0:self._output_range, :]
        input_tensor = self._attention_layer_norm(input_tensor)
      target_tensor = input_tensor[:, 0:self._output_range, :]
      if attention_mask is not None:
        attention_mask = attention_mask[:, 0:self._output_range, :]
    else:
      if self._norm_first:
        source_tensor = input_tensor
        input_tensor = self._attention_layer_norm(input_tensor)
      target_tensor = input_tensor
    # Keys and values always span the full (possibly normalized) input.
    key_value = input_tensor

    query = target_tensor
    key = key_value
    if pos_embed is not None:
      # NOTE(review): with `output_range` set the query is sliced but
      # `pos_embed` is not, so this addition relies on the two shapes being
      # broadcast-compatible -- confirm against callers that use output_range.
      query = query + pos_embed
      key = key + pos_embed
    attention_output = self._attention_layer(
        query=query,
        key=key,
        value=key_value,
        attention_mask=attention_mask)
    attention_output = self._attention_dropout(attention_output)
    if self._norm_first:
      attention_output = source_tensor + attention_output
    else:
      attention_output = self._attention_layer_norm(target_tensor +
                                                    attention_output)
    if self._norm_first:
      source_attention_output = attention_output
      attention_output = self._output_layer_norm(attention_output)
    inner_output = self._intermediate_dense(attention_output)
    inner_output = self._intermediate_activation_layer(inner_output)
    inner_output = self._inner_dropout_layer(inner_output)
    layer_output = self._output_dense(inner_output)
    layer_output = self._output_dropout(layer_output)

    if self._norm_first:
      return source_attention_output + layer_output

    # During mixed precision training, layer norm output is always fp32 for now.
    # Casts fp32 for the subsequent add.
    layer_output = tf.cast(layer_output, tf.float32)
    return self._output_layer_norm(layer_output + attention_output)
class TransformerDecoder(tf.keras.layers.Layer):
  """Stack of Transformer decoder blocks followed by a layer normalization.

  The decoder runs `num_layers` `TransformerDecoderBlock` layers in sequence.
  Each block is composed of the sublayers:
    1. Self-attention layer
    2. Multi-headed attention layer combining encoder outputs with results from
       the previous self-attention layer.
    3. Feedforward network (2 fully-connected layers)
  """

  def __init__(self,
               num_layers=6,
               num_attention_heads=8,
               intermediate_size=2048,
               activation="relu",
               dropout_rate=0.0,
               attention_dropout_rate=0.0,
               use_bias=False,
               norm_first=True,
               norm_epsilon=1e-6,
               intermediate_dropout=0.0,
               **kwargs):
    """Stores the hyper-parameters; sub-layers are created in `build()`.

    Args:
      num_layers: Number of decoder blocks in the stack.
      num_attention_heads: Number of attention heads per block.
      intermediate_size: Size of the intermediate (feedforward) layer.
      activation: Activation for the intermediate layer.
      dropout_rate: Dropout probability for the block outputs.
      attention_dropout_rate: Dropout probability inside the attention layers.
      use_bias: Whether the attention layers use a bias term.
      norm_first: If `True`, inputs to the attention and intermediate dense
        layers are normalized; if `False`, their outputs are normalized.
      norm_epsilon: Epsilon value used by the normalization layers.
      intermediate_dropout: Dropout probability applied after the intermediate
        activation.
      **kwargs: Keyword arguments passed to `tf.keras.layers.Layer`.
    """
    super().__init__(**kwargs)
    # Shape of the stack.
    self.num_layers = num_layers
    self.num_attention_heads = num_attention_heads
    self._intermediate_size = intermediate_size
    self._activation = activation
    # Regularization.
    self._dropout_rate = dropout_rate
    self._attention_dropout_rate = attention_dropout_rate
    self._intermediate_dropout = intermediate_dropout
    # Attention / normalization behaviour.
    self._use_bias = use_bias
    self._norm_first = norm_first
    self._norm_epsilon = norm_epsilon

  def build(self, input_shape):
    """Creates the decoder blocks and the final layer normalization."""

    def _make_block(index):
      # Every block gets its own clone of the attention initializer, sized by
      # the third entry of the input shape.
      initializer = tf_utils.clone_initializer(
          models.seq2seq_transformer.attention_initializer(input_shape[2]))
      return TransformerDecoderBlock(
          num_attention_heads=self.num_attention_heads,
          intermediate_size=self._intermediate_size,
          intermediate_activation=self._activation,
          dropout_rate=self._dropout_rate,
          attention_dropout_rate=self._attention_dropout_rate,
          use_bias=self._use_bias,
          norm_first=self._norm_first,
          norm_epsilon=self._norm_epsilon,
          intermediate_dropout=self._intermediate_dropout,
          attention_initializer=initializer,
          name=("layer_%d" % index))

    self.decoder_layers = [_make_block(i) for i in range(self.num_layers)]
    self.output_normalization = tf.keras.layers.LayerNormalization(
        epsilon=self._norm_epsilon, dtype="float32")
    super().build(input_shape)

  def get_config(self):
    """Returns the base layer config extended with the decoder's settings."""
    decoder_config = {
        "num_layers": self.num_layers,
        "num_attention_heads": self.num_attention_heads,
        "intermediate_size": self._intermediate_size,
        "activation": self._activation,
        "dropout_rate": self._dropout_rate,
        "attention_dropout_rate": self._attention_dropout_rate,
        "use_bias": self._use_bias,
        "norm_first": self._norm_first,
        "norm_epsilon": self._norm_epsilon,
        "intermediate_dropout": self._intermediate_dropout
    }
    # Decoder entries take precedence over same-named base entries.
    return {**super().get_config(), **decoder_config}

  def call(self,
           target,
           memory,
           self_attention_mask=None,
           cross_attention_mask=None,
           cache=None,
           decode_loop_step=None,
           return_all_decoder_outputs=False,
           input_pos_embed=None,
           memory_pos_embed=None):
    """Runs the target through every decoder block, then normalizes.

    Args:
      target: A tensor with shape `(batch_size, target_length, hidden_size)`.
      memory: A tensor with shape `(batch_size, input_length, hidden_size)`.
      self_attention_mask: A tensor with shape `(batch_size, target_length,
        target_length)`, the mask for the decoder self-attention layer.
      cross_attention_mask: A tensor with shape `(batch_size, target_length,
        input_length)`, the mask for the encoder-decoder attention layer.
      cache: (Used for fast decoding) A nested dictionary keyed by the layer
        index as a string. Each entry is handed to the matching decoder block
        and replaced by the cache that block returns.
      decode_loop_step: An integer, the step number of the decoding loop. Only
        forwarded to the blocks when `cache` is given.
      return_all_decoder_outputs: If True, return the layer-normalized output
        of every decoder block instead of only the last one. This is useful
        when introducing a per-layer auxiliary loss.
      input_pos_embed: Position embedding for the target. Each block adds it
        to the query and key of self-attention and to the query of
        cross-attention.
      memory_pos_embed: Position embedding for the memory. Each block adds it
        to the key of cross-attention.

    Returns:
      The layer-normalized output of the last block, with shape
      `(batch_size, target_length, hidden_size)`, or a list with one such
      tensor per block when `return_all_decoder_outputs` is True.
    """
    hidden_states = target
    normalized_outputs = []
    for layer_idx in range(self.num_layers):
      decoder_block = self.decoder_layers[layer_idx]
      block_inputs = [
          hidden_states, memory, cross_attention_mask, self_attention_mask,
          input_pos_embed, memory_pos_embed
      ]
      if cache is not None:
        # Thread this layer's decoding cache through the block and store the
        # updated cache back under the same key.
        cache_key = str(layer_idx)
        hidden_states, cache[cache_key] = decoder_block(
            block_inputs,
            cache=cache[cache_key],
            decode_loop_step=decode_loop_step)
      else:
        hidden_states, _ = decoder_block(block_inputs)
      if return_all_decoder_outputs:
        normalized_outputs.append(self.output_normalization(hidden_states))

    if return_all_decoder_outputs:
      return normalized_outputs
    return self.output_normalization(hidden_states)
class TransformerDecoderBlock(tf.keras.layers.Layer):
  """Single transformer layer for decoder.

  It has three sub-layers:
  (1) a multi-head self-attention mechanism.
  (2) a encoder-decoder attention.
  (3) a positionwise fully connected feed-forward network.

  Position embeddings, when given, are added to the query and key of
  self-attention and to the query and key of encoder-decoder attention.
  """

  def __init__(self,
               num_attention_heads,
               intermediate_size,
               intermediate_activation,
               dropout_rate=0.0,
               attention_dropout_rate=0.0,
               kernel_initializer="glorot_uniform",
               bias_initializer="zeros",
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               use_bias=True,
               norm_first=False,
               norm_epsilon=1e-12,
               intermediate_dropout=0.0,
               attention_initializer=None,
               **kwargs):
    """Initialize a Transformer decoder block.

    Args:
      num_attention_heads: Number of attention heads.
      intermediate_size: Size of the intermediate layer.
      intermediate_activation: Activation for the intermediate layer.
      dropout_rate: Dropout probability for the post-attention and output
        dropout.
      attention_dropout_rate: Dropout probability for within the attention
        layer.
      kernel_initializer: Initializer for dense layer kernels.
      bias_initializer: Initializer for dense layer biases.
      kernel_regularizer: Regularizer for dense layer kernels.
      bias_regularizer: Regularizer for dense layer biases.
      activity_regularizer: Regularizer for dense layer activity.
      kernel_constraint: Constraint for dense layer kernels.
      bias_constraint: Constraint for dense layer biases.
      use_bias: Whether to enable use_bias in attention layer. If set False,
        use_bias in attention layer is disabled.
      norm_first: Whether to normalize inputs to attention and intermediate
        dense layers. If set False, output of attention and intermediate dense
        layers is normalized.
      norm_epsilon: Epsilon value to initialize normalization layers.
      intermediate_dropout: Dropout probability for intermediate_dropout_layer.
      attention_initializer: Initializer for kernels of attention layers. If set
        `None`, attention layers use kernel_initializer as initializer for
        kernel.
      **kwargs: Keyword arguments passed to `tf.keras.layers.Layer`.
    """
    super().__init__(**kwargs)
    self.num_attention_heads = num_attention_heads
    self.intermediate_size = intermediate_size
    self.intermediate_activation = tf.keras.activations.get(
        intermediate_activation)
    self.dropout_rate = dropout_rate
    self.attention_dropout_rate = attention_dropout_rate
    self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
    self._bias_initializer = tf.keras.initializers.get(bias_initializer)
    self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
    self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
    self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
    self._bias_constraint = tf.keras.constraints.get(bias_constraint)
    self._use_bias = use_bias
    self._norm_first = norm_first
    self._norm_epsilon = norm_epsilon
    self._intermediate_dropout = intermediate_dropout
    if attention_initializer:
      self._attention_initializer = tf.keras.initializers.get(
          attention_initializer)
    else:
      self._attention_initializer = tf_utils.clone_initializer(
          self._kernel_initializer)
    self._cross_attention_cls = layers.attention.MultiHeadAttention

  def build(self, input_shape):
    """Creates the attention, feed-forward, dropout and norm sub-layers.

    Args:
      input_shape: A sequence of shapes whose first entry is the shape of the
        target tensor, `[batch, sequence, width]`.

    Raises:
      ValueError: If the target tensor is not three-dimensional, or if its
        width is not a multiple of the number of attention heads.
    """
    target_tensor_shape = tf.TensorShape(input_shape[0])
    if len(target_tensor_shape.as_list()) != 3:
      raise ValueError("TransformerLayer expects a three-dimensional input of "
                       "shape [batch, sequence, width].")
    hidden_size = target_tensor_shape[2]
    if hidden_size % self.num_attention_heads != 0:
      raise ValueError(
          "The hidden size (%d) is not a multiple of the number of attention "
          "heads (%d)" % (hidden_size, self.num_attention_heads))
    self.attention_head_size = int(hidden_size) // self.num_attention_heads
    common_kwargs = dict(
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint)
    # Self attention.
    self.self_attention = layers.attention.CachedAttention(
        num_heads=self.num_attention_heads,
        key_dim=self.attention_head_size,
        dropout=self.attention_dropout_rate,
        use_bias=self._use_bias,
        kernel_initializer=self._attention_initializer,
        name="self_attention",
        **common_kwargs)
    # NOTE(review): this layer is created but never applied in `call()`, and
    # it shares the name "output" with `self.output_dense` below. It is left
    # in place so the set of attributes on the layer is unchanged.
    self.self_attention_output_dense = tf.keras.layers.EinsumDense(
        "abc,cd->abd",
        output_shape=(None, hidden_size),
        bias_axes="d",
        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
        name="output",
        **common_kwargs)
    self.self_attention_dropout = tf.keras.layers.Dropout(
        rate=self.dropout_rate)
    self.self_attention_layer_norm = (
        tf.keras.layers.LayerNormalization(
            name="self_attention_layer_norm",
            axis=-1,
            epsilon=self._norm_epsilon,
            dtype="float32"))
    # Encoder-decoder attention.
    self.encdec_attention = self._cross_attention_cls(
        num_heads=self.num_attention_heads,
        key_dim=self.attention_head_size,
        dropout=self.attention_dropout_rate,
        output_shape=hidden_size,
        use_bias=self._use_bias,
        kernel_initializer=self._attention_initializer,
        name="attention/encdec",
        **common_kwargs)

    self.encdec_attention_dropout = tf.keras.layers.Dropout(
        rate=self.dropout_rate)
    self.encdec_attention_layer_norm = (
        tf.keras.layers.LayerNormalization(
            name="attention/encdec_output_layer_norm",
            axis=-1,
            epsilon=self._norm_epsilon,
            dtype="float32"))

    # Feed-forward projection.
    self.intermediate_dense = tf.keras.layers.EinsumDense(
        "abc,cd->abd",
        output_shape=(None, self.intermediate_size),
        bias_axes="d",
        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
        name="intermediate",
        **common_kwargs)
    self.intermediate_activation_layer = tf.keras.layers.Activation(
        self.intermediate_activation)
    self._intermediate_dropout_layer = tf.keras.layers.Dropout(
        rate=self._intermediate_dropout)
    self.output_dense = tf.keras.layers.EinsumDense(
        "abc,cd->abd",
        output_shape=(None, hidden_size),
        bias_axes="d",
        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
        name="output",
        **common_kwargs)
    self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
    self.output_layer_norm = tf.keras.layers.LayerNormalization(
        name="output_layer_norm",
        axis=-1,
        epsilon=self._norm_epsilon,
        dtype="float32")
    super().build(input_shape)

  def get_config(self):
    """Returns the base layer config extended with this block's settings."""
    config = {
        "num_attention_heads":
            self.num_attention_heads,
        "intermediate_size":
            self.intermediate_size,
        "intermediate_activation":
            tf.keras.activations.serialize(self.intermediate_activation),
        "dropout_rate":
            self.dropout_rate,
        "attention_dropout_rate":
            self.attention_dropout_rate,
        "kernel_initializer":
            tf.keras.initializers.serialize(self._kernel_initializer),
        "bias_initializer":
            tf.keras.initializers.serialize(self._bias_initializer),
        "kernel_regularizer":
            tf.keras.regularizers.serialize(self._kernel_regularizer),
        "bias_regularizer":
            tf.keras.regularizers.serialize(self._bias_regularizer),
        "activity_regularizer":
            tf.keras.regularizers.serialize(self._activity_regularizer),
        "kernel_constraint":
            tf.keras.constraints.serialize(self._kernel_constraint),
        "bias_constraint":
            tf.keras.constraints.serialize(self._bias_constraint),
        "use_bias":
            self._use_bias,
        "norm_first":
            self._norm_first,
        "norm_epsilon":
            self._norm_epsilon,
        "intermediate_dropout":
            self._intermediate_dropout,
        "attention_initializer":
            tf.keras.initializers.serialize(self._attention_initializer)
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def common_layers_with_encoder(self):
    """Gets layer objects that can make a Transformer encoder block."""
    return [
        self.self_attention, self.self_attention_layer_norm,
        self.intermediate_dense, self.output_dense, self.output_layer_norm
    ]

  @staticmethod
  def _with_pos_embed(tensor, pos_embed):
    """Returns `tensor + pos_embed`, or `tensor` when there is no embedding."""
    if pos_embed is None:
      return tensor
    return tensor + pos_embed

  def call(self, inputs, cache=None, decode_loop_step=None):
    """Applies self-attention, encoder-decoder attention and feed-forward.

    Args:
      inputs: A sequence of six items, in this order: `input_tensor`, `memory`,
        `attention_mask` (mask for encoder-decoder attention),
        `self_attention_mask`, `input_pos_embed`, `memory_pos_embed`. Either
        position embedding may be `None`, in which case nothing is added.
      cache: Optional cache forwarded to the self-attention layer.
      decode_loop_step: Optional decoding step forwarded to the self-attention
        layer.

    Returns:
      A tuple `(layer_output, cache)`, where `cache` is the value returned by
      the self-attention layer.
    """
    (input_tensor, memory, attention_mask, self_attention_mask,
     input_pos_embed, memory_pos_embed) = inputs
    source_tensor = input_tensor
    if self._norm_first:
      input_tensor = self.self_attention_layer_norm(input_tensor)
    # Query and key share the same position-augmented tensor; the value does
    # not receive the position embedding.
    self_attention_query_key = self._with_pos_embed(input_tensor,
                                                    input_pos_embed)
    self_attention_output, cache = self.self_attention(
        query=self_attention_query_key,
        key=self_attention_query_key,
        value=input_tensor,
        attention_mask=self_attention_mask,
        cache=cache,
        decode_loop_step=decode_loop_step)
    self_attention_output = self.self_attention_dropout(self_attention_output)
    if self._norm_first:
      self_attention_output = source_tensor + self_attention_output
    else:
      self_attention_output = self.self_attention_layer_norm(
          input_tensor + self_attention_output)
    if self._norm_first:
      source_self_attention_output = self_attention_output
      self_attention_output = self.encdec_attention_layer_norm(
          self_attention_output)
    cross_attn_inputs = dict(
        query=self._with_pos_embed(self_attention_output, input_pos_embed),
        key=self._with_pos_embed(memory, memory_pos_embed),
        value=memory,
        attention_mask=attention_mask)
    attention_output = self.encdec_attention(**cross_attn_inputs)
    attention_output = self.encdec_attention_dropout(attention_output)
    if self._norm_first:
      attention_output = source_self_attention_output + attention_output
    else:
      attention_output = self.encdec_attention_layer_norm(
          self_attention_output + attention_output)
    if self._norm_first:
      source_attention_output = attention_output
      attention_output = self.output_layer_norm(attention_output)

    intermediate_output = self.intermediate_dense(attention_output)
    intermediate_output = self.intermediate_activation_layer(
        intermediate_output)
    intermediate_output = self._intermediate_dropout_layer(intermediate_output)
    layer_output = self.output_dense(intermediate_output)
    layer_output = self.output_dropout(layer_output)
    if self._norm_first:
      layer_output = source_attention_output + layer_output
    else:
      layer_output = self.output_layer_norm(layer_output + attention_output)
    return layer_output, cache
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for transformer."""
import tensorflow as tf
from official.projects.detr.modeling import transformer
class TransformerTest(tf.test.TestCase):
def test_transformer_encoder_block(self):
  """The encoder block output keeps the (batch, sequence, feature) shape."""
  batch_size, sequence_length, feature_size = 2, 100, 256
  sequence_shape = (batch_size, sequence_length, feature_size)
  mask_shape = (batch_size, sequence_length, sequence_length)

  block = transformer.TransformerEncoderBlock(
      num_attention_heads=2, inner_dim=256, inner_activation='relu')
  block_inputs = [
      tf.ones(sequence_shape),  # input tensor
      tf.ones(mask_shape, dtype=tf.int64),  # attention mask
      tf.ones(sequence_shape),  # position embedding
  ]
  out = block(block_inputs)

  self.assertAllEqual(tf.shape(out), sequence_shape)
def test_transformer_encoder_block_get_config(self):
  """get_config() reports the constructor arguments and their defaults."""
  block = transformer.TransformerEncoderBlock(
      num_attention_heads=2, inner_dim=256, inner_activation='relu')

  # Serialized form of the default "glorot_uniform" initializer, expected for
  # both the dense kernels and the attention kernels.
  glorot_uniform = {'class_name': 'GlorotUniform', 'config': {'seed': None}}
  expected_config = dict(
      name='transformer_encoder_block',
      trainable=True,
      dtype='float32',
      num_attention_heads=2,
      inner_dim=256,
      inner_activation='relu',
      output_dropout=0.0,
      attention_dropout=0.0,
      output_range=None,
      kernel_initializer=glorot_uniform,
      bias_initializer={'class_name': 'Zeros', 'config': {}},
      kernel_regularizer=None,
      bias_regularizer=None,
      activity_regularizer=None,
      kernel_constraint=None,
      bias_constraint=None,
      use_bias=True,
      norm_first=False,
      norm_epsilon=1e-12,
      inner_dropout=0.0,
      attention_initializer=glorot_uniform,
      attention_axes=None,
  )
  self.assertAllEqual(expected_config, block.get_config())
def test_transformer_encoder(self):
  """The encoder stack output keeps the (batch, sequence, feature) shape."""
  batch_size, sequence_length, feature_size = 2, 100, 256
  sequence_shape = (batch_size, sequence_length, feature_size)
  mask_shape = (batch_size, sequence_length, sequence_length)

  encoder = transformer.TransformerEncoder(
      num_layers=2, num_attention_heads=2, intermediate_size=256)
  input_tensor = tf.ones(sequence_shape)
  attention_mask = tf.ones(mask_shape, dtype=tf.int64)
  pos_embed = tf.ones(sequence_shape)
  out = encoder(input_tensor, attention_mask, pos_embed)

  self.assertAllEqual(tf.shape(out), sequence_shape)
def test_transformer_encoder_get_config(self):
  """Compares `get_config()` of a two-layer encoder against a literal dict."""
  encoder = transformer.TransformerEncoder(
      num_layers=2, num_attention_heads=2, intermediate_size=256)
  actual_config = encoder.get_config()
  expected_config = {
      'name': 'transformer_encoder',
      'trainable': True,
      'dtype': 'float32',
      # Values passed to the constructor above.
      'num_layers': 2,
      'num_attention_heads': 2,
      'intermediate_size': 256,
      # Values that were not passed explicitly.
      'activation': 'relu',
      'dropout_rate': 0.0,
      'attention_dropout_rate': 0.0,
      'use_bias': False,
      'norm_first': True,
      'norm_epsilon': 1e-06,
      'intermediate_dropout': 0.0,
  }
  self.assertAllEqual(expected_config, actual_config)
def test_transformer_decoder_block(self):
  """Runs one decoder block on all-ones inputs and checks the output shape."""
  batch_size, sequence_length, memory_length, feature_size = 2, 100, 200, 256
  decoder_block = transformer.TransformerDecoderBlock(2, 256, 'relu')

  target_shape = (batch_size, sequence_length, feature_size)
  memory_shape = (batch_size, memory_length, feature_size)
  # The block takes all six tensors packed into a single list, in this order.
  block_inputs = [
      tf.ones(target_shape),  # input_tensor
      tf.ones(memory_shape),  # memory
      tf.ones((batch_size, sequence_length, memory_length),
              dtype=tf.int64),  # attention_mask (target x memory)
      tf.ones((batch_size, sequence_length, sequence_length),
              dtype=tf.int64),  # self_attention_mask (target x target)
      tf.ones(target_shape),  # input_pos_embed
      tf.ones(memory_shape),  # memory_pos_embed
  ]

  # The call returns a pair; only the first element is checked here.
  decoded, _ = decoder_block(block_inputs)
  self.assertAllEqual(tf.shape(decoded), target_shape)
def test_transformer_decoder_block_get_config(self):
  """Compares `get_config()` of a decoder block against a literal dict.

  Only the number of heads, the intermediate size and the intermediate
  activation are passed (positionally) to the constructor.
  """
  decoder_block = transformer.TransformerDecoderBlock(2, 256, 'relu')
  actual_config = decoder_block.get_config()
  expected_config = {
      'name': 'transformer_decoder_block',
      'trainable': True,
      'dtype': 'float32',
      # Values passed to the constructor above.
      'num_attention_heads': 2,
      'intermediate_size': 256,
      'intermediate_activation': 'relu',
      # Values that were not passed explicitly.
      'dropout_rate': 0.0,
      'attention_dropout_rate': 0.0,
      'kernel_initializer': {
          'class_name': 'GlorotUniform',
          'config': {'seed': None},
      },
      'bias_initializer': {'class_name': 'Zeros', 'config': {}},
      'kernel_regularizer': None,
      'bias_regularizer': None,
      'activity_regularizer': None,
      'kernel_constraint': None,
      'bias_constraint': None,
      'use_bias': True,
      'norm_first': False,
      'norm_epsilon': 1e-12,
      'intermediate_dropout': 0.0,
      'attention_initializer': {
          'class_name': 'GlorotUniform',
          'config': {'seed': None},
      },
  }
  self.assertAllEqual(expected_config, actual_config)
def test_transformer_decoder(self):
  """Runs a two-layer decoder and checks the shape of every returned output."""
  batch_size, sequence_length, memory_length, feature_size = 2, 100, 200, 256
  num_layers = 2
  decoder = transformer.TransformerDecoder(
      num_layers=num_layers, num_attention_heads=2, intermediate_size=256)

  target_shape = (batch_size, sequence_length, feature_size)
  memory_shape = (batch_size, memory_length, feature_size)
  targets = tf.ones(target_shape)
  memory = tf.ones(memory_shape)
  cross_mask = tf.ones(
      (batch_size, sequence_length, memory_length), dtype=tf.int64)
  self_mask = tf.ones(
      (batch_size, sequence_length, sequence_length), dtype=tf.int64)
  target_positions = tf.ones(target_shape)
  memory_positions = tf.ones(memory_shape)

  # Note the positional order: the self-attention mask precedes the
  # cross-attention mask.
  decoded_per_layer = decoder(
      targets,
      memory,
      self_mask,
      cross_mask,
      return_all_decoder_outputs=True,
      input_pos_embed=target_positions,
      memory_pos_embed=memory_positions)

  # With return_all_decoder_outputs=True the test expects two outputs, which
  # matches the number of layers configured above.
  self.assertLen(decoded_per_layer, 2)
  for decoded in decoded_per_layer:
    self.assertAllEqual(tf.shape(decoded), target_shape)
def test_transformer_decoder_get_config(self):
  """Compares `get_config()` of a two-layer decoder against a literal dict."""
  decoder = transformer.TransformerDecoder(
      num_layers=2, num_attention_heads=2, intermediate_size=256)
  actual_config = decoder.get_config()
  expected_config = {
      'name': 'transformer_decoder',
      'trainable': True,
      'dtype': 'float32',
      # Values passed to the constructor above.
      'num_layers': 2,
      'num_attention_heads': 2,
      'intermediate_size': 256,
      # Values that were not passed explicitly.
      'activation': 'relu',
      'dropout_rate': 0.0,
      'attention_dropout_rate': 0.0,
      'use_bias': False,
      'norm_first': True,
      'norm_epsilon': 1e-06,
      'intermediate_dropout': 0.0,
  }
  self.assertAllEqual(expected_config, actual_config)
# Hand control to the TensorFlow test runner when this file is executed as a
# script (as opposed to being imported).
if __name__ == '__main__':
  tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tensorflow implementation to solve the Linear Sum Assignment problem.
The Linear Sum Assignment problem involves determining the minimum weight
matching for bipartite graphs. For example, this problem can be defined by
a 2D matrix C, where each element i,j determines the cost of matching worker i
with job j. The solution to the problem is a complete assignment of jobs to
workers, such that no job is assigned to more than one worker and no worker
is assigned more than one job, with minimum cost.
This implementation builds off of the Hungarian
Matching Algorithm (https://www.cse.ust.hk/~golin/COMP572/Notes/Matching.pdf).
Based on the original implementation by Jiquan Ngiam <jngiam@google.com>.
"""
import tensorflow as tf
from official.modeling import tf_utils
def _prepare(weights):
  """Reduces each cost matrix by its row minima and then its column minima.

  After the row reduction every entry is non-negative and every row holds at
  least one zero; the column reduction then does the same for every column.
  Subtracting a constant from a whole row or column does not change which
  assignment is optimal, but it gives the greedy assignment a better starting
  point. This corresponds to the pre-processing and step 1 of the Hungarian
  algorithm as described on Wikipedia.

  Args:
    weights: A float32 [batch_size, num_elems, num_elems] tensor, where each
      inner matrix holds the weights to be used for matching.

  Returns:
    A reduced weights tensor of the same shape and dtype.
  """
  # Every worker (row) needs a job, so remove the cheapest job cost per row.
  row_minima = tf.reduce_min(weights, axis=2, keepdims=True)
  weights -= row_minima

  # Every job (column) needs a worker, so remove the cheapest remaining cost
  # per column. This must be computed on the row-reduced weights.
  column_minima = tf.reduce_min(weights, axis=1, keepdims=True)
  weights -= column_minima
  return weights
def _greedy_assignment(adj_matrix):
  """Builds an initial (possibly partial) matching one worker at a time.

  Workers (rows) are visited in index order. Each worker takes a single job
  (column) that is both adjacent to it and not yet taken by an earlier worker.
  A worker with no such job is left unmatched, so the result may be a partial
  assignment.

  Args:
    adj_matrix: A bool [batch_size, num_elems, num_elems] tensor, where each
      element of the inner matrix represents whether the worker (row) can be
      matched to the job (column).

  Returns:
    A bool [batch_size, num_elems, num_elems] tensor, where each element of the
    inner matrix represents whether the worker has been matched to the job.
    Each row and column has at most one true element.
  """
  _, num_elems, _ = tf_utils.get_shape_list(adj_matrix, expected_rank=3)

  # Move the worker axis to the front so that tf.foldl, which iterates over
  # the leading axis, sees one [batch_size, num_elems] adjacency row per step.
  rows_first = tf.transpose(adj_matrix, [1, 0, 2])

  # One slot per worker, each holding that worker's chosen-job mask.
  row_choices = tf.TensorArray(tf.bool, num_elems)
  # Jobs already claimed by an earlier worker, per batch element.
  taken_jobs = tf.zeros_like(rows_first[0, ...], dtype=tf.bool)

  def _assign_one_worker(carry, row_adj):
    """Assigns a job to the current worker and updates the running state."""
    worker_idx, choices, taken = carry

    # Only jobs that are adjacent and still free are eligible.
    eligible = row_adj & (~taken)

    # Pick a single eligible job via argmax over the 0/1 mask.
    picked_idx = tf.argmax(
        tf.cast(eligible, tf.int32), axis=1, output_type=tf.int32)
    picked = tf.one_hot(
        picked_idx,
        num_elems,
        on_value=True,
        off_value=False,
        dtype=tf.bool)
    # argmax still returns an index when nothing is eligible; masking with
    # `eligible` clears that spurious pick so the worker stays unmatched.
    picked &= eligible

    taken |= picked
    choices = choices.write(worker_idx, picked)
    return (worker_idx + 1, choices, taken)

  _, row_choices, _ = tf.foldl(
      _assign_one_worker,
      rows_first, (0, row_choices, taken_jobs),
      back_prop=False)

  # Stack back to [num_elems, batch_size, num_elems] and restore batch-major
  # layout.
  stacked = row_choices.stack()
  return tf.transpose(stacked, [1, 0, 2])
def _find_augmenting_path(assignment, adj_matrix):
  """Finds an augmenting path given an assignment and an adjacency matrix.

  The augmenting path search starts from the unassigned workers, then goes on
  to find jobs (via an unassigned pairing), then back again to workers (via an
  existing pairing), and so on. The path alternates between unassigned and
  existing pairings. Returns the state after the search.

  Note: In the state, the worker and job indices are 1-indexed so that we can
  use 0 to represent unreachable nodes. State contains the following keys:

  - jobs: A [batch_size, 1, num_elems] tensor containing the highest index
      unassigned worker that can reach this job through a path.
  - jobs_from_worker: A [batch_size, num_elems] tensor containing the worker
      reached immediately before this job.
  - workers: A [batch_size, num_elems, 1] tensor containing the highest index
      unassigned worker that can reach this worker through a path.
  - workers_from_job: A [batch_size, num_elems] tensor containing the job
      reached immediately before this worker.
  - new_jobs: A bool [batch_size, num_elems] tensor containing True if the
      unassigned job can be reached via a path.

  State can be used to recover the path via backtracking.

  Args:
    assignment: A bool [batch_size, num_elems, num_elems] tensor, where each
      element of the inner matrix represents whether the worker has been matched
      to the job. This may be a partial assignment.
    adj_matrix: A bool [batch_size, num_elems, num_elems] tensor, where each
      element of the inner matrix represents whether the worker (row) can be
      matched to the job (column).

  Returns:
    A state dict, which represents the outcome of running an augmenting
    path search on the graph given the assignment.
  """
  batch_size, num_elems, _ = tf_utils.get_shape_list(
      assignment, expected_rank=3)
  # A worker (row) or job (column) is unassigned when its row / column of the
  # assignment matrix contains no True entry.
  unassigned_workers = ~tf.reduce_any(assignment, axis=2, keepdims=True)
  unassigned_jobs = ~tf.reduce_any(assignment, axis=1, keepdims=True)

  # Integer 0/1 masks so they can be multiplied with the worker / job ids
  # below: edges that are available but unused, and edges currently in use.
  unassigned_pairings = tf.cast(adj_matrix & ~assignment, tf.int32)
  existing_pairings = tf.cast(assignment, tf.int32)

  # Initialize unassigned workers to have non-zero ids, assigned workers will
  # have ids = 0.
  worker_indices = tf.range(1, num_elems + 1, dtype=tf.int32)
  init_workers = tf.tile(worker_indices[tf.newaxis, :, tf.newaxis],
                         [batch_size, 1, 1])
  init_workers *= tf.cast(unassigned_workers, tf.int32)

  state = {
      "jobs": tf.zeros((batch_size, 1, num_elems), dtype=tf.int32),
      "jobs_from_worker": tf.zeros((batch_size, num_elems), dtype=tf.int32),
      "workers": init_workers,
      "workers_from_job": tf.zeros((batch_size, num_elems), dtype=tf.int32)
  }

  def _has_active_workers(state, curr_workers):
    """Returns True while the search frontier still contains any worker."""
    del state
    # Frontier entries are positive ids or 0, so a positive sum means at least
    # one worker anywhere in the batch is still on the frontier.
    return tf.reduce_sum(curr_workers) > 0

  def _augment_step(state, curr_workers):
    """Performs one search step: workers -> jobs -> workers."""
    # Note: These steps could be potentially much faster if sparse matrices are
    # supported. The unassigned_pairings and existing_pairings matrices can be
    # very sparse.

    # Find potential jobs using current workers.
    # Broadcasting [B, N, 1] * [B, N, N] labels each unused edge with the id
    # carried by its worker; the max over the worker axis keeps one id per job.
    potential_jobs = curr_workers * unassigned_pairings
    curr_jobs = tf.reduce_max(potential_jobs, axis=1, keepdims=True)
    # 1-indexed worker through which each job is reached in this step.
    curr_jobs_from_worker = 1 + tf.argmax(
        potential_jobs, axis=1, output_type=tf.int32)

    # Remove already accessible jobs from curr_jobs.
    default_jobs = tf.zeros_like(state["jobs"], dtype=state["jobs"].dtype)
    curr_jobs = tf.where(state["jobs"] > 0, default_jobs, curr_jobs)
    # Zero the predecessor for jobs that were not newly reached, so that the
    # tf.maximum update below leaves their stored predecessor unchanged.
    curr_jobs_from_worker *= tf.cast(curr_jobs > 0, tf.int32)[:, 0, :]

    # Find potential workers from current jobs.
    # Broadcasting [B, 1, N] * [B, N, N] follows existing pairings from the
    # newly reached jobs back to the workers they are currently matched with.
    potential_workers = curr_jobs * existing_pairings
    curr_workers = tf.reduce_max(potential_workers, axis=2, keepdims=True)
    # 1-indexed job through which each worker is reached in this step.
    curr_workers_from_job = 1 + tf.argmax(
        potential_workers, axis=2, output_type=tf.int32)

    # Remove already accessible workers from curr_workers.
    default_workers = tf.zeros_like(state["workers"])
    curr_workers = tf.where(
        state["workers"] > 0, default_workers, curr_workers)
    curr_workers_from_job *= tf.cast(curr_workers > 0, tf.int32)[:, :, 0]

    # Update state so that we can backtrack later.
    # Copy first so the dict passed in by tf.while_loop is not mutated.
    state = state.copy()
    state["jobs"] = tf.maximum(state["jobs"], curr_jobs)
    state["jobs_from_worker"] = tf.maximum(state["jobs_from_worker"],
                                           curr_jobs_from_worker)
    state["workers"] = tf.maximum(state["workers"], curr_workers)
    state["workers_from_job"] = tf.maximum(state["workers_from_job"],
                                           curr_workers_from_job)

    return state, curr_workers

  # Expand the frontier until no new worker is reached anywhere in the batch.
  state, _ = tf.while_loop(
      _has_active_workers,
      _augment_step, (state, init_workers),
      back_prop=False)

  # Compute new jobs, this is useful for determining termination of the
  # maximum bi-partite matching and initialization for backtracking.
  new_jobs = (state["jobs"] > 0) & unassigned_jobs
  # Drop the singleton worker axis: [batch_size, 1, num_elems] ->
  # [batch_size, num_elems].
  state["new_jobs"] = new_jobs[:, 0, :]
  return state
def _improve_assignment(assignment, state):
  """Improves an assignment by backtracking the augmented path using state.

  For each example in the batch that has a reachable unassigned job, one
  augmenting path is walked backwards from that job and every pairing on the
  path is flipped (unused pairings become used and vice versa).

  Args:
    assignment: A bool [batch_size, num_elems, num_elems] tensor, where each
      element of the inner matrix represents whether the worker has been matched
      to the job. This may be a partial assignment.
    state: A dict, which represents the outcome of running an augmenting path
      search on the graph given the assignment.

  Returns:
    A new assignment matrix of the same shape and type as assignment, where the
    assignment has been updated using the augmented path found.
  """
  batch_size, num_elems, _ = tf_utils.get_shape_list(assignment, 3)

  # We store the current job id and iteratively backtrack using jobs_from_worker
  # and workers_from_job until we reach an unassigned worker. We flip all the
  # assignments on this path to discover a better overall assignment.

  # Note: The indices in state are 1-indexed, where 0 represents that the
  # worker / job cannot be reached.

  # Obtain initial job indices based on new_jobs.
  curr_job_idx = tf.argmax(
      tf.cast(state["new_jobs"], tf.int32), axis=1, output_type=tf.int32)

  # Track whether an example is actively being backtracked. Since we are
  # operating on a batch, not all examples in the batch may be active.
  # argmax also returns an index for rows with no new job; looking that index
  # up in new_jobs yields False for such rows, marking them inactive.
  active = tf.gather(state["new_jobs"], curr_job_idx, batch_dims=1)
  batch_range = tf.range(0, batch_size, dtype=tf.int32)

  # Flip matrix tracks which assignments we need to flip - corresponding to the
  # augmenting path taken. We use an integer tensor here so that we can use
  # tensor_scatter_nd_add to update the tensor, and then cast it back to bool
  # after the loop.
  flip_matrix = tf.zeros((batch_size, num_elems, num_elems), dtype=tf.int32)

  def _has_active_backtracks(flip_matrix, active, curr_job_idx):
    """Check if any example in the batch is still being backtracked."""
    del flip_matrix, curr_job_idx
    return tf.reduce_any(active)

  def _backtrack_one_step(flip_matrix, active, curr_job_idx):
    """Take one step in backtracking."""
    # Discover the worker that the job originated from, note that this worker
    # must exist by construction.
    curr_worker_idx = tf.gather(
        state["jobs_from_worker"], curr_job_idx, batch_dims=1) - 1
    # Clamp so that inactive examples (stored id 0, i.e. -1 here) still
    # produce a valid index for the scatter and gather below.
    curr_worker_idx = tf.maximum(curr_worker_idx, 0)
    update_indices = tf.stack([batch_range, curr_worker_idx, curr_job_idx],
                              axis=1)
    update_indices = tf.maximum(update_indices, 0)
    # Mark the (worker, job) pairing for flipping. Inactive examples add 0,
    # so their flip matrix is left unchanged.
    flip_matrix = tf.tensor_scatter_nd_add(flip_matrix, update_indices,
                                           tf.cast(active, tf.int32))

    # Discover the (potential) job that the worker originated from.
    curr_job_idx = tf.gather(
        state["workers_from_job"], curr_worker_idx, batch_dims=1) - 1
    # Note that jobs may not be active, and we track that here (before
    # adjusting indices so that they are all >= 0 for gather).
    active &= curr_job_idx >= 0
    curr_job_idx = tf.maximum(curr_job_idx, 0)
    update_indices = tf.stack([batch_range, curr_worker_idx, curr_job_idx],
                              axis=1)
    update_indices = tf.maximum(update_indices, 0)
    # Mark the pairing between this worker and the job it was reached from
    # for flipping as well; examples that just became inactive add 0.
    flip_matrix = tf.tensor_scatter_nd_add(flip_matrix, update_indices,
                                           tf.cast(active, tf.int32))

    return flip_matrix, active, curr_job_idx

  flip_matrix, _, _ = tf.while_loop(
      _has_active_backtracks,
      _backtrack_one_step, (flip_matrix, active, curr_job_idx),
      back_prop=False)

  # Any non-zero count means "flip"; XOR toggles exactly those pairings.
  flip_matrix = tf.cast(flip_matrix, tf.bool)
  assignment = tf.math.logical_xor(assignment, flip_matrix)

  return assignment
def _maximum_bipartite_matching(adj_matrix, assignment=None):
  """Grows a matching by repeatedly applying augmenting paths.

  Starting from `assignment` (or from a greedy assignment when none is given),
  the function alternates between searching for an augmenting path and
  flipping the pairings along it, until the search no longer reaches any
  unassigned job in any batch element.

  Args:
    adj_matrix: A bool [batch_size, num_elems, num_elems] tensor, where each
      element of the inner matrix represents whether the worker (row) can be
      matched to the job (column).
    assignment: An optional bool [batch_size, num_elems, num_elems] tensor,
      where each element of the inner matrix represents whether the worker has
      been matched to the job. This may be a partial assignment. If specified,
      this assignment is used to seed the iterative algorithm.

  Returns:
    A tuple `(state, assignment)`: the state dict of the final augmenting path
    search and the resulting assignment tensor. The state can be used to
    compute a minimum vertex cover for the bipartite graph.
  """
  if assignment is None:
    assignment = _greedy_assignment(adj_matrix)

  search_state = _find_augmenting_path(assignment, adj_matrix)

  def _path_found(search_state, assignment):
    """Returns True while some batch element can still reach a free job."""
    del assignment
    return tf.reduce_any(search_state["new_jobs"])

  def _augment_then_search(search_state, assignment):
    """Applies the path recorded in the state, then searches again."""
    improved = _improve_assignment(assignment, search_state)
    next_state = _find_augmenting_path(improved, adj_matrix)
    return next_state, improved

  search_state, assignment = tf.while_loop(
      _path_found,
      _augment_then_search, (search_state, assignment),
      back_prop=False)

  return search_state, assignment
def _compute_cover(state, assignment):
  """Derives a row/column cover from the final alternating path search.

  The construction follows the proof of Kőnig's theorem at
  https://en.wikipedia.org/wiki/K%C5%91nig%27s_theorem_(graph_theory)#Proof:
  the cover consists of the matched workers that the search did not reach
  together with the matched jobs that the search did reach.

  Args:
    state: A state dict, which represents the outcome of running an augmenting
      path search on the graph given the assignment. Only the "workers" and
      "jobs" entries are read.
    assignment: A bool [batch_size, num_elems, num_elems] tensor, where each
      element of the inner matrix represents whether the worker has been matched
      to the job.

  Returns:
    A tuple of (workers_cover, jobs_cover) corresponding to row and column
    covers for the bipartite graph. workers_cover is a boolean tensor of shape
    [batch_size, num_elems, 1] and jobs_cover is a boolean tensor of shape
    [batch_size, 1, num_elems].
  """
  # A worker / job is matched when its row / column has any True entry.
  matched_workers = tf.reduce_any(assignment, axis=2, keepdims=True)
  matched_jobs = tf.reduce_any(assignment, axis=1, keepdims=True)

  # State ids are 1-indexed; a positive id means the search reached the node.
  worker_reached = state["workers"] > 0
  job_reached = state["jobs"] > 0

  return matched_workers & (~worker_reached), matched_jobs & job_reached
def _update_weights_using_cover(workers_cover, jobs_cover, weights):
  """Adjusts the weights by the smallest uncovered value of each matrix.

  For every matrix in the batch the minimum weight among cells covered by
  neither a row nor a column is found. That value is subtracted from all
  uncovered cells and added to all cells covered by both a row and a column;
  cells covered exactly once are unchanged.

  Args:
    workers_cover: A boolean tensor of shape [batch_size, num_elems, 1].
    jobs_cover: A boolean tensor of shape [batch_size, 1, num_elems].
    weights: A float32 [batch_size, num_elems, num_elems] tensor, where each
      inner matrix holds the weights to be used for matching.

  Returns:
    A new weight matrix with elements adjusted by the cover.
  """
  ones = tf.ones_like(weights)
  zeros = tf.zeros_like(weights)

  # The [B, N, 1] and [B, 1, N] covers broadcast to a full [B, N, N] mask.
  in_any_cover = workers_cover | jobs_cover
  in_both_covers = workers_cover & jobs_cover

  # Overwrite covered cells with the global maximum so that they never lower
  # the per-matrix minimum taken over the uncovered cells.
  global_max = tf.reduce_max(weights)
  uncovered_only = tf.where(in_any_cover, ones * global_max, weights)
  delta = tf.reduce_min(uncovered_only, axis=[-2, -1], keepdims=True)

  # Broadcast the per-matrix delta to every cell before selecting.
  delta_per_cell = ones * delta
  increase = tf.where(in_both_covers, delta_per_cell, zeros)
  decrease = tf.where(in_any_cover, zeros, delta_per_cell)

  return weights + increase - decrease
def assert_rank(tensor, expected_rank, name=None):
  """Checks that the rank of `tensor` is one of the accepted ranks.

  The rank is taken as `len(tensor.shape)`.

  Args:
    tensor: A tf.Tensor to check the rank of.
    expected_rank: Python integer or iterable of integers, the accepted
      rank(s).
    name: Optional name of the tensor, used only in the error message.

  Raises:
    ValueError: If the actual rank is not among the expected ranks.
  """
  # Normalise the single-int and iterable forms to one set of accepted ranks.
  if isinstance(expected_rank, int):
    accepted_ranks = {expected_rank}
  else:
    accepted_ranks = set(expected_rank)

  actual_rank = len(tensor.shape)
  if actual_rank in accepted_ranks:
    return

  raise ValueError(
      "For the tensor `%s`, the actual tensor rank `%d` (shape = %s) is not "
      "equal to the expected tensor rank `%s`" %
      (name, actual_rank, str(tensor.shape), str(expected_rank)))
def hungarian_matching(weights):
  """Computes the minimum linear sum assignment using the Hungarian algorithm.

  The weights are first reduced so that zeros mark candidate pairings. A
  maximum matching over the zero entries is then computed together with a
  row/column cover. While the total number of covering rows and columns,
  summed over the whole batch, is below batch_size * num_elems, the weights
  are adjusted using the cover to create new zeros and the matching is
  extended, seeded with the previous assignment.

  Args:
    weights: A float32 [batch_size, num_elems, num_elems] tensor, where each
      inner matrix holds the weights to be used for matching.

  Returns:
    A tuple `(weights, assignment)`:
      weights: The adjusted weights tensor after the final iteration, with the
        same shape as the input.
      assignment: A bool [batch_size, num_elems, num_elems] tensor, where each
        element of the inner matrix represents whether the worker has been
        matched to the job.
  """
  batch_size, num_elems, _ = tf_utils.get_shape_list(weights, 3)

  weights = _prepare(weights)
  # Zero-cost cells are the pairings currently available for matching.
  zero_cells = tf.equal(weights, 0.)
  state, assignment = _maximum_bipartite_matching(zero_cells)
  workers_cover, jobs_cover = _compute_cover(state, assignment)

  def _needs_more_zeros(workers_cover, jobs_cover, *unused_loop_vars):
    """Returns True while the batch-wide cover size is below B * N."""
    del unused_loop_vars
    covered_rows = tf.reduce_sum(tf.cast(workers_cover, tf.int32))
    covered_cols = tf.reduce_sum(tf.cast(jobs_cover, tf.int32))
    return tf.less(covered_rows + covered_cols, batch_size * num_elems)

  def _adjust_and_rematch(workers_cover, jobs_cover, weights, assignment):
    """Creates new zero cells from the cover and extends the matching."""
    weights = _update_weights_using_cover(workers_cover, jobs_cover, weights)
    zero_cells = tf.equal(weights, 0.)
    # Seed with the current assignment instead of starting from scratch.
    state, assignment = _maximum_bipartite_matching(zero_cells, assignment)
    workers_cover, jobs_cover = _compute_cover(state, assignment)
    return workers_cover, jobs_cover, weights, assignment

  _, _, weights, assignment = tf.while_loop(
      _needs_more_zeros,
      _adjust_and_rematch,
      (workers_cover, jobs_cover, weights, assignment),
      back_prop=False)

  return weights, assignment
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment