# Towards End-to-End Unified Scene Text Detection and Layout Analysis
![unified detection](docs/images/task.png)
[![UnifiedDetector](https://img.shields.io/badge/UnifiedDetector-arxiv.2203.15143-green)](https://arxiv.org/abs/2203.15143)
Official TensorFlow 2 implementation of the paper `Towards End-to-End Unified
Scene Text Detection and Layout Analysis`. If you encounter any issues using the
code, you are welcome to report them on the Issues tab or email us directly at
`hiertext@google.com`.
## Installation
### Set up TensorFlow Models
```bash
# (Optional) Create and enter a virtual environment
pip3 install --user virtualenv
virtualenv -p python3 unified_detector
source ./unified_detector/bin/activate
# First clone the TensorFlow Models project:
git clone https://github.com/tensorflow/models.git
# Install the requirements of TensorFlow Models and this repo:
cd models
pip3 install -r official/requirements.txt
pip3 install -r official/projects/unified_detector/requirements.txt
# Compile the protos
# If `protoc` is not installed, please follow: https://grpc.io/docs/protoc-installation/
export PYTHONPATH=${PYTHONPATH}:${PWD}/research/
cd research/object_detection/
protoc protos/string_int_label_map.proto --python_out=.
```
### Set up Deeplab2
```bash
# Clone Deeplab2 anywhere you like
cd <somewhere>
git clone https://github.com/google-research/deeplab2.git
# Compile the protos
protoc deeplab2/*.proto --python_out=.
# Add to PYTHONPATH the directory where deeplab2 sits.
export PYTHONPATH=${PYTHONPATH}:${PWD}
```
## Running the model on images using the provided checkpoint
### Download the checkpoint
Model | Input Resolution | #object queries | line PQ (val) | paragraph PQ (val) | line PQ (test) | paragraph PQ (test)
---------------------------------------------------------------------------------------------------------------------------------- | ---------------- | ------------- | ------------- | ------------------ | -------------- | -------------------
Unified-Detector-Line ([ckpt](https://storage.cloud.google.com/tf_model_garden/vision/unified_detector/unified_detector_ckpt.tgz)) | 1024 | 384 | 61.04 | 52.84 | 62.20 | 53.52
### Demo on single images
```bash
# run from `models/`
python3 -m official.projects.unified_detector.run_inference \
--gin_file=official/projects/unified_detector/configs/gin_files/unified_detector_model.gin \
--ckpt_path=<path-of-the-ckpt> \
--img_file=<some-image> \
--output_path=<some-directory>/demo.jsonl \
--vis_dir=<some-directory>
```
The output will be stored as a JSONL file in the same hierarchical format
required by the evaluation script of the HierText dataset, together with
visualizations of the word/line/paragraph boundaries. Note that the unified
detector produces line-level masks and an affinity matrix for grouping lines
into paragraphs. For visualization purposes, we split each line mask into pixel
groups defined as connected components and visualize these groups as `words`,
although they are not necessarily at the word granularity. Lines and paragraphs
are visualized as groupings of these `words` using axis-aligned bounding boxes.
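As a quick sanity check of the output, the following minimal sketch counts the
detected paragraphs, lines, and words per image. It assumes each line of the
JSONL file is one image record using the same `paragraphs` → `lines` → `words`
hierarchy as the HierText annotations; treat the exact keys as illustrative
rather than guaranteed by the repo.
```python
import json

# 'demo.jsonl' is the file written via --output_path above.
with open('demo.jsonl') as f:
    for raw_line in f:
        record = json.loads(raw_line)
        paragraphs = record.get('paragraphs', [])
        num_lines = sum(len(p.get('lines', [])) for p in paragraphs)
        num_words = sum(
            len(l.get('words', []))
            for p in paragraphs for l in p.get('lines', []))
        print(record.get('image_id'), len(paragraphs), num_lines, num_words)
```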
## Inference and Evaluation on the HierText dataset
### Download the HierText dataset
Clone the [HierText repo](https://github.com/google-research-datasets/hiertext)
and download the dataset. The `requirements.txt` in this folder already covers
the requirements of the HierText repo, so there is no need to create another
virtual environment.
### Inference and eval
The following commands run the model on the validation set and compute the
score. Note that the test set annotations are not released yet, so only the
validation set is used here for demo purposes.
#### Inference
```bash
# Run from `models/`
python3 -m official.projects.unified_detector.run_inference \
--gin_file=official/projects/unified_detector/configs/gin_files/unified_detector_model.gin \
--ckpt_path=<path-of-the-ckpt> \
--img_dir=<the-directory-containing-validation-images> \
--output_path=<some-directory>/validation_output.jsonl
```
#### Evaluation
```bash
# Run from `hiertext/`
python3 eval.py \
--gt=gt/validation.jsonl \
--result=<some-directory>/validation_output.jsonl \
--output=./validation-score.txt \
--mask_stride=1 \
--eval_lines \
--eval_paragraphs \
--num_workers=0
```
## Train new models
First, you will need to convert the HierText dataset into TFRecords:
```bash
# Run from `models/official/projects/unified_detector/data_conversion`
CUDA_VISIBLE_DEVICES='' python3 convert.py \
--gt_file=/path/to/gt.jsonl \
--img_dir=/path/to/image \
--out_file=/path/to/tfrecords/file-prefix
```
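After conversion, it can be useful to sanity-check one shard by parsing a
single serialized example. The sketch below is illustrative only; it relies on
the feature keys written by the conversion utilities (`image/source_id`,
`image/height`, `image/width`, `image/object/classes`), and the glob pattern
should match the `--out_file` prefix used above.
```python
import tensorflow as tf

files = tf.io.gfile.glob('/path/to/tfrecords/file-prefix*')
for serialized in tf.data.TFRecordDataset(files).take(1):
    example = tf.train.Example.FromString(serialized.numpy())
    feature = example.features.feature
    print('image_id:', feature['image/source_id'].bytes_list.value[0])
    print('height x width:', feature['image/height'].int64_list.value[0],
          feature['image/width'].int64_list.value[0])
    print('num entities:',
          len(feature['image/object/classes'].int64_list.value))
```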
To train the unified detector, run the following script:
```bash
# Run from `models/`
python3 -m official.projects.unified_detector.train \
--mode=train \
--experiment=unified_detector \
--model_dir='<some path>' \
--gin_file='official/projects/unified_detector/configs/gin_files/unified_detector_train.gin' \
--gin_file='official/projects/unified_detector/configs/gin_files/unified_detector_model.gin' \
--gin_params='InputFn.input_paths = ["/path/to/tfrecords/file-prefix*"]'
```
## Citation
Please cite our [paper](https://arxiv.org/pdf/2203.15143.pdf) if you find this
work helpful:
```
@inproceedings{long2022towards,
title={Towards End-to-End Unified Scene Text Detection and Layout Analysis},
author={Long, Shangbang and Qin, Siyang and Panteleev, Dmitry and Bissacco, Alessandro and Fujii, Yasuhisa and Raptis, Michalis},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2022}
}
```
# Defining the unified detector models.
# Model
## Backbone
num_slots = 384
SyncBatchNormalization.momentum = 0.95
get_max_deep_lab_backbone.num_slots = %num_slots
## Decoder
intermediate_filters = 256
num_entity_class = 3 # C + 1 (bkg) + 1 (void)
_get_decoder_head.atrous_rates = (6, 12, 18)
_get_decoder_head.pixel_space_dim = 128
_get_decoder_head.pixel_space_intermediate = %intermediate_filters
_get_decoder_head.num_classes = %num_entity_class
_get_decoder_head.aux_sem_intermediate = %intermediate_filters
_get_decoder_head.low_level = [
{'feature_key': 'res3', 'channels_project': 64,},
{'feature_key': 'res2', 'channels_project': 32,},]
_get_decoder_head.norm_fn = @SyncBatchNormalization
_get_embed_head.norm_fn = @LayerNorm
# Loss
# pq loss
alpha = 0.75
tau = 0.3
_entity_mask_loss.alpha = %alpha
_instance_discrimination_loss.tau = %tau
_paragraph_grouping_loss.tau = %tau
_paragraph_grouping_loss.loss_mode = 'balanced'
# Other Model setting
UniversalDetector.mask_threshold = 0.4
UniversalDetector.class_threshold = 0.5
UniversalDetector.filter_area = 32
universal_detection_loss_weights.loss_segmentation_word = 1e0
universal_detection_loss_weights.loss_inst_dist = 1e0
universal_detection_loss_weights.loss_mask_id = 1e-4
universal_detection_loss_weights.loss_pq = 3e0
universal_detection_loss_weights.loss_para = 1e0
# Defining the input pipeline of unified detector.
# ===== ===== Model ===== =====
# Internal import 2.
OcrTask.model_fn = @UniversalDetector
# ===== ===== Data pipeline ===== =====
InputFn.parser_fn = @UniDetectorParserFn
InputFn.dataset_type = 'tfrecord'
InputFn.batch_size = 256
# Internal import 3.
UniDetectorParserFn.output_dimension = 1024
# Simple data augmentation for now.
UniDetectorParserFn.rot90_probability = 0.0
UniDetectorParserFn.use_color_distortion = True
UniDetectorParserFn.crop_min_scale = 0.5
UniDetectorParserFn.crop_max_scale = 1.5
UniDetectorParserFn.crop_min_aspect = 0.8
UniDetectorParserFn.crop_max_aspect = 1.25
UniDetectorParserFn.max_num_instance = 384
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""OCR tasks and models configurations."""
import dataclasses
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import optimization
@dataclasses.dataclass
class OcrTaskConfig(cfg.TaskConfig):
train_data: cfg.DataConfig = cfg.DataConfig()
model_call_needs_labels: bool = False
@exp_factory.register_config_factory('unified_detector')
def unified_detector() -> cfg.ExperimentConfig:
"""Configurations for trainer of unified detector."""
total_train_steps = 100000
summary_interval = steps_per_loop = 200
checkpoint_interval = 2000
warmup_steps = 1000
config = cfg.ExperimentConfig(
# Input pipeline and model are configured through Gin.
task=OcrTaskConfig(train_data=cfg.DataConfig(is_training=True)),
trainer=cfg.TrainerConfig(
train_steps=total_train_steps,
steps_per_loop=steps_per_loop,
summary_interval=summary_interval,
checkpoint_interval=checkpoint_interval,
max_to_keep=1,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adamw',
'adamw': {
'weight_decay_rate': 0.05,
'include_in_weight_decay': [
'^((?!depthwise).)*(kernel|weights):0$',
],
'exclude_from_weight_decay': [
'(^((?!kernel).)*:0)|(depthwise_kernel)',
],
'gradient_clip_norm': 10.,
},
},
'learning_rate': {
'type': 'cosine',
'cosine': {
'initial_learning_rate': 1e-3,
'decay_steps': total_train_steps - warmup_steps,
'alpha': 1e-2,
'offset': warmup_steps,
},
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_learning_rate': 1e-5,
'warmup_steps': warmup_steps,
}
},
}),
),
)
return config
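# Illustrative only (not part of the original file): the experiment registered
# above can be retrieved by name through the factory, which is how the
# trainer's `--experiment=unified_detector` flag resolves it. A minimal sketch:
#
#   from official.core import exp_factory
#   config = exp_factory.get_exp_config('unified_detector')
#   assert config.trainer.train_steps == 100000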
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Script to convert HierText to TFExamples.
This script is only intended to run locally.
python3 data_preprocess/convert.py \
--gt_file=/path/to/gt.jsonl \
--img_dir=/path/to/image \
--out_file=/path/to/tfrecords/file-prefix
"""
import json
import os
import random
from absl import app
from absl import flags
import tensorflow as tf
import tqdm
import utils
_GT_FILE = flags.DEFINE_string('gt_file', None, 'Path to the GT file')
_IMG_DIR = flags.DEFINE_string('img_dir', None, 'Path to the image folder.')
_OUT_FILE = flags.DEFINE_string('out_file', None, 'Path for the tfrecords.')
_NUM_SHARD = flags.DEFINE_integer(
'num_shard', 100, 'The number of shards of tfrecords.')
def main(unused_argv) -> None:
annotations = json.load(open(_GT_FILE.value))['annotations']
random.shuffle(annotations)
n_sample = len(annotations)
n_shards = _NUM_SHARD.value
n_sample_per_shard = (n_sample - 1) // n_shards + 1
for shard in tqdm.tqdm(range(n_shards)):
output_path = f'{_OUT_FILE.value}-{shard:05}-{n_shards:05}.tfrecords'
annotation_subset = annotations[
shard * n_sample_per_shard : (shard + 1) * n_sample_per_shard]
with tf.io.TFRecordWriter(output_path) as file_writer:
for annotation in annotation_subset:
img_file_path = os.path.join(_IMG_DIR.value,
f"{annotation['image_id']}.jpg")
tfexample = utils.convert_to_tfe(img_file_path, annotation)
file_writer.write(tfexample)
if __name__ == '__main__':
flags.mark_flags_as_required(['gt_file', 'img_dir', 'out_file'])
app.run(main)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities to convert data to TFExamples and store in TFRecords."""
from typing import Any, Dict, List, Tuple, Union
import cv2
import numpy as np
import tensorflow as tf
def encode_image(
image_tensor: np.ndarray,
encoding_type: str = 'png') -> Union[np.ndarray, tf.Tensor]:
"""Encode image tensor into byte string."""
if encoding_type == 'jpg':
image_encoded = tf.image.encode_jpeg(tf.constant(image_tensor))
elif encoding_type == 'png':
image_encoded = tf.image.encode_png(tf.constant(image_tensor))
else:
raise ValueError('Invalid encoding type.')
if tf.executing_eagerly():
image_encoded = image_encoded.numpy()
else:
image_encoded = image_encoded.eval()
return image_encoded
def int64_feature(value: Union[int, List[int]]) -> tf.train.Feature:
if not isinstance(value, list):
value = [value]
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def float_feature(value: Union[float, List[float]]) -> tf.train.Feature:
if not isinstance(value, list):
value = [value]
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
def bytes_feature(value: Union[Union[bytes, str], List[Union[bytes, str]]]
) -> tf.train.Feature:
if not isinstance(value, list):
value = [value]
for i in range(len(value)):
if not isinstance(value[i], bytes):
value[i] = value[i].encode('utf-8')
return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
def annotation_to_entities(annotation: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Flatten the annotation dict to a list of 'entities'."""
entities = []
for paragraph in annotation['paragraphs']:
paragraph_id = len(entities)
paragraph['type'] = 3 # 3 for paragraph
paragraph['parent_id'] = -1
entities.append(paragraph)
for line in paragraph['lines']:
line_id = len(entities)
line['type'] = 2 # 2 for line
line['parent_id'] = paragraph_id
entities.append(line)
for word in line['words']:
word['type'] = 1 # 1 for word
word['parent_id'] = line_id
entities.append(word)
return entities
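# Illustrative example (not in the original file): for an annotation with one
# paragraph containing one line with two words, the returned list is
#   [paragraph, line, word, word]
# with types [3, 2, 1, 1] and parent_ids [-1, 0, 1, 1], i.e. `parent_id`
# indexes into the returned list itself.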
def draw_entity_mask(
entities: List[Dict[str, Any]],
image_shape: Tuple[int, int, int]) -> np.ndarray:
"""Draw entity id mask.
Args:
entities: A list of entity objects. Should be output from
`annotation_to_entities`.
image_shape: The shape of the input image.
Returns:
A (H, W, 3) entity id mask with the same height/width as the image. The value
at (i, j, :) encodes the entity id of that pixel. Only word entities are
rendered; non-text pixels are 0 and word entity ids start from 1.
"""
instance_mask = np.zeros(image_shape, dtype=np.uint8)
for i, entity in enumerate(entities):
# only draw word masks
if entity['type'] != 1:
continue
vertices = np.array(entity['vertices'])
# the pixel value is actually 1 + position in entities
entity_id = i + 1
if entity_id >= 65536:
# As entity_id is encoded in the last two channels, it should be less than
# 256**2=65536.
raise ValueError(
(f'Entity ID overflow: {entity_id}. Currently only entity_id<65536 '
'are supported.'))
# use the last two channels to encode the entity id.
color = [0, entity_id // 256, entity_id % 256]
instance_mask = cv2.fillPoly(instance_mask,
[np.round(vertices).astype('int32')], color)
return instance_mask
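# Note (added for illustration): since the id is written as
# [0, entity_id // 256, entity_id % 256], it can be recovered from the mask via
#   entity_id = mask[..., 1].astype(np.int32) * 256 + mask[..., 2]
# where 0 marks non-text pixels and word entity ids start from 1.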
def convert_to_tfe(img_file_name: str,
annotation: Dict[str, Any]) -> tf.train.Example:
"""Convert the annotation dict into a TFExample."""
img = cv2.imread(img_file_name)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
h, w, c = img.shape
encoded_img = encode_image(img)
entities = annotation_to_entities(annotation)
masks = draw_entity_mask(entities, img.shape)
encoded_mask = encode_image(masks)
# encode attributes
parent = []
classes = []
content_type = []
text = []
vertices = []
for entity in entities:
parent.append(entity['parent_id'])
classes.append(entity['type'])
# 0 for annotated; 8 for not annotated
content_type.append((0 if entity['legible'] else 8))
text.append(entity.get('text', ''))
v = np.array(entity['vertices'])
vertices.append(','.join(str(float(n)) for n in v.reshape(-1)))
example = tf.train.Example(
features=tf.train.Features(
feature={
# input images
'image/encoded': bytes_feature(encoded_img),
# image format
'image/format': bytes_feature('png'),
# image width
'image/width': int64_feature([w]),
# image height
'image/height': int64_feature([h]),
# image channels
'image/channels': int64_feature([c]),
# image key
'image/source_id': bytes_feature(annotation['image_id']),
# HxWx3 tensors: channel 2-3 encodes the id of the word entity.
'image/additional_channels/encoded': bytes_feature(encoded_mask),
# format of the additional channels
'image/additional_channels/format': bytes_feature('png'),
'image/object/parent': int64_feature(parent),
# word / line / paragraph / symbol / ...
'image/object/classes': int64_feature(classes),
# text / handwritten / not-annotated / ...
'image/object/content_type': int64_feature(content_type),
# string text transcription
'image/object/text': bytes_feature(text),
# comma separated coordinates, (x,y) * n
'image/object/vertices': bytes_feature(vertices),
})).SerializeToString()
return example
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AutoAugment and RandAugment policies for enhanced image preprocessing.
AutoAugment Reference: https://arxiv.org/abs/1805.09501
RandAugment Reference: https://arxiv.org/abs/1909.13719
This library is adapted from:
`https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py`.
Several changes are made. They are inspired by the TIMM library:
https://github.com/rwightman/pytorch-image-models/tree/master/timm/data
Changes include:
(1) Random Erasing / Cutout is added, and separated from the random augmentation
pool (not sampled as an operation).
(2) For `posterize` and `solarize`, the arguments are changed such that the
level of corruption increases as the `magnitude` argument increases.
(3) `color`, `contrast`, `brightness`, `sharpness` are randomly enhanced or
diminished.
(4) Magnitude is randomly sampled from a normal distribution.
(5) Operations are applied with a probability.
"""
import inspect
import math
import tensorflow as tf
import tensorflow_addons.image as tfa_image
# This signifies the max integer that the controller RNN could predict for the
# augmentation scheme.
_MAX_LEVEL = 10.
def policy_v0():
"""Autoaugment policy that was used in AutoAugment Paper."""
# Each tuple is an augmentation operation of the form
# (operation, probability, magnitude). Each element in policy is a
# sub-policy that will be applied sequentially on the image.
policy = [
[('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
[('Color', 0.4, 9), ('Equalize', 0.6, 3)],
[('Color', 0.4, 1), ('Rotate', 0.6, 8)],
[('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
[('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
[('Color', 0.2, 0), ('Equalize', 0.8, 8)],
[('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
[('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
[('Color', 0.6, 1), ('Equalize', 1.0, 2)],
[('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
[('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
[('Color', 0.4, 7), ('Equalize', 0.6, 0)],
[('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
[('Solarize', 0.6, 8), ('Color', 0.6, 9)],
[('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
[('Rotate', 1.0, 7), ('TranslateY', 0.8, 9)],
[('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
[('ShearY', 0.8, 0), ('Color', 0.6, 4)],
[('Color', 1.0, 0), ('Rotate', 0.6, 2)],
[('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
[('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
[('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
[('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],
[('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
[('Color', 0.8, 6), ('Rotate', 0.4, 5)],
]
return policy
def policy_vtest():
"""Autoaugment test policy for debugging."""
# Each tuple is an augmentation operation of the form
# (operation, probability, magnitude). Each element in policy is a
# sub-policy that will be applied sequentially on the image.
policy = [
[('TranslateX', 1.0, 4), ('Equalize', 1.0, 10)],
]
return policy
# pylint: disable=g-long-lambda
blend = tf.function(lambda i1, i2, factor: tf.cast(
tfa_image.blend(tf.cast(i1, tf.float32), tf.cast(i2, tf.float32), factor),
tf.uint8))
# pylint: enable=g-long-lambda
def random_erase(image,
prob,
min_area=0.02,
max_area=1 / 3,
min_aspect=1 / 3,
max_aspect=10 / 3,
mode='pixel'):
"""The random erasing augmentations: https://arxiv.org/pdf/1708.04896.pdf.
This augmentation is applied after image normalization.
Args:
image: Input image after all other augmentation and normalization. It has
type tf.float32.
prob: Probability of applying the random erasing operation.
min_area: As named.
max_area: As named.
min_aspect: As named.
max_aspect: As named.
mode: How the erased area is filled. 'pixel' fills it with random noise
(truncated normal); any other value fills it with zeros.
Returns:
Randomly erased image.
"""
image_height = tf.shape(image)[0]
image_width = tf.shape(image)[1]
image_area = tf.cast(image_width * image_height, tf.float32)
# Sample width, height
erase_area = tf.random.uniform([], min_area, max_area) * image_area
log_max_target_ar = tf.math.log(
tf.minimum(
tf.math.divide(
tf.math.square(tf.cast(image_width, tf.float32)), erase_area),
max_aspect))
log_min_target_ar = tf.math.log(
tf.maximum(
tf.math.divide(erase_area,
tf.math.square(tf.cast(image_height, tf.float32))),
min_aspect))
erase_aspect_ratio = tf.math.exp(
tf.random.uniform([], log_min_target_ar, log_max_target_ar))
erase_h = tf.cast(tf.math.sqrt(erase_area / erase_aspect_ratio), tf.int32)
erase_w = tf.cast(tf.math.sqrt(erase_area * erase_aspect_ratio), tf.int32)
# Sample (left, top) of the rectangle to erase
erase_left = tf.random.uniform(
shape=[], minval=0, maxval=image_width - erase_w, dtype=tf.int32)
erase_top = tf.random.uniform(
shape=[], minval=0, maxval=image_height - erase_h, dtype=tf.int32)
pad_right = image_width - erase_w - erase_left
pad_bottom = image_height - erase_h - erase_top
mask = tf.pad(
tf.zeros([erase_h, erase_w], dtype=image.dtype),
[[erase_top, pad_bottom], [erase_left, pad_right]],
constant_values=1)
mask = tf.expand_dims(mask, -1) # [H, W, 1]
if mode == 'pixel':
fill = tf.random.truncated_normal(
tf.shape(image), 0.0, 1.0, dtype=image.dtype)
else:
fill = tf.zeros(tf.shape(image), dtype=image.dtype)
should_apply_op = tf.cast(
tf.floor(tf.random.uniform([], dtype=tf.float32) + prob), tf.bool)
augmented_image = tf.cond(should_apply_op,
lambda: mask * image + (1 - mask) * fill,
lambda: image)
return augmented_image
def solarize(image, threshold=128):
# For each pixel in the image, keep the pixel value
# if it is less than the threshold.
# Otherwise, invert it (i.e. replace it with 255 - pixel value).
return tf.where(image < threshold, image, 255 - image)
def solarize_add(image, addition=0, threshold=128):
# For each pixel in the image less than threshold
# we add 'addition' amount to it and then clip the
# pixel value to be between 0 and 255. The value
# of 'addition' is between -128 and 128.
added_image = tf.cast(image, tf.int64) + addition
added_image = tf.cast(tf.clip_by_value(added_image, 0, 255), tf.uint8)
return tf.where(image < threshold, added_image, image)
def color(image, factor):
"""Equivalent of PIL Color."""
degenerate = tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(image))
return blend(degenerate, image, factor)
def contrast(image, factor):
"""Equivalent of PIL Contrast."""
degenerate = tf.image.rgb_to_grayscale(image)
# Cast before calling tf.histogram.
degenerate = tf.cast(degenerate, tf.int32)
# Compute the grayscale histogram, then compute the mean pixel value,
# and create a constant image size of that value. Use that as the
# blending degenerate target of the original image.
hist = tf.histogram_fixed_width(degenerate, [0, 255], nbins=256)
mean = tf.reduce_sum(tf.cast(hist, tf.float32)) / 256.0
degenerate = tf.ones_like(degenerate, dtype=tf.float32) * mean
degenerate = tf.clip_by_value(degenerate, 0.0, 255.0)
degenerate = tf.image.grayscale_to_rgb(tf.cast(degenerate, tf.uint8))
return blend(degenerate, image, factor)
def brightness(image, factor):
"""Equivalent of PIL Brightness."""
degenerate = tf.zeros_like(image)
return blend(degenerate, image, factor)
def posterize(image, bits):
"""Equivalent of PIL Posterize. Smaller `bits` means larger degradation."""
shift = 8 - bits
return tf.bitwise.left_shift(tf.bitwise.right_shift(image, shift), shift)
def rotate(image, degrees, replace):
"""Rotates the image by degrees either clockwise or counterclockwise.
Args:
image: An image Tensor of type uint8.
degrees: Float, a scalar angle in degrees to rotate all images by. If
degrees is positive the image will be rotated clockwise otherwise it will
be rotated counterclockwise.
replace: A one or three value 1D tensor to fill empty pixels caused by the
rotate operation.
Returns:
The rotated version of image.
"""
# Convert from degrees to radians.
degrees_to_radians = math.pi / 180.0
radians = degrees * degrees_to_radians
# In practice, we should randomize the rotation degrees by flipping
# it negatively half the time, but that's done on 'degrees' outside
# of the function.
if isinstance(replace, list) or isinstance(replace, tuple):
replace = replace[0]
image = tfa_image.rotate(image, radians, fill_value=replace)
return image
def translate_x(image, pixels, replace):
"""Equivalent of PIL Translate in X dimension."""
return tfa_image.translate_xy(image, [-pixels, 0], replace)
def translate_y(image, pixels, replace):
"""Equivalent of PIL Translate in Y dimension."""
return tfa_image.translate_xy(image, [0, -pixels], replace)
def autocontrast(image):
"""Implements Autocontrast function from PIL using TF ops.
Args:
image: A 3D uint8 tensor.
Returns:
The image after it has had autocontrast applied to it and will be of type
uint8.
"""
def scale_channel(image):
"""Scale the 2D image using the autocontrast rule."""
# A possibly cheaper version can be done using cumsum/unique_with_counts
# over the histogram values, rather than iterating over the entire image
# to compute mins and maxes.
lo = tf.cast(tf.reduce_min(image), tf.float32)
hi = tf.cast(tf.reduce_max(image), tf.float32)
# Scale the image, making the lowest value 0 and the highest value 255.
def scale_values(im):
scale = 255.0 / (hi - lo)
offset = -lo * scale
im = tf.cast(im, tf.float32) * scale + offset
im = tf.clip_by_value(im, 0.0, 255.0)
return tf.cast(im, tf.uint8)
result = tf.cond(hi > lo, lambda: scale_values(image), lambda: image)
return result
# Assumes RGB for now. Scales each channel independently
# and then stacks the result.
s1 = scale_channel(image[:, :, 0])
s2 = scale_channel(image[:, :, 1])
s3 = scale_channel(image[:, :, 2])
image = tf.stack([s1, s2, s3], 2)
return image
def sharpness(image, factor):
"""Implements Sharpness function from PIL using TF ops."""
orig_image = image
image = tf.cast(image, tf.float32)
# Make image 4D for conv operation.
image = tf.expand_dims(image, 0)
# SMOOTH PIL Kernel.
kernel = tf.constant([[1, 1, 1], [1, 5, 1], [1, 1, 1]],
dtype=tf.float32,
shape=[3, 3, 1, 1]) / 13.
# Tile across channel dimension.
kernel = tf.tile(kernel, [1, 1, 3, 1])
strides = [1, 1, 1, 1]
with tf.device('/cpu:0'):
# Some augmentations that use depth-wise conv will crash when training
# on GPU. See (b/156242594) for details.
degenerate = tf.nn.depthwise_conv2d(image, kernel, strides, padding='VALID')
degenerate = tf.clip_by_value(degenerate, 0.0, 255.0)
degenerate = tf.squeeze(tf.cast(degenerate, tf.uint8), [0])
# For the borders of the resulting image, fill in the values of the
# original image.
mask = tf.ones_like(degenerate)
padded_mask = tf.pad(mask, [[1, 1], [1, 1], [0, 0]])
padded_degenerate = tf.pad(degenerate, [[1, 1], [1, 1], [0, 0]])
result = tf.where(tf.equal(padded_mask, 1), padded_degenerate, orig_image)
# Blend the final result.
return blend(result, orig_image, factor)
def equalize(image):
"""Implements Equalize function from PIL using TF ops."""
def scale_channel(im, c):
"""Scale the data in the channel to implement equalize."""
im = tf.cast(im[:, :, c], tf.int32)
# Compute the histogram of the image channel.
histo = tf.histogram_fixed_width(im, [0, 255], nbins=256)
# For the purposes of computing the step, filter out the nonzeros.
nonzero = tf.where(tf.not_equal(histo, 0))
nonzero_histo = tf.reshape(tf.gather(histo, nonzero), [-1])
step = (tf.reduce_sum(nonzero_histo) - nonzero_histo[-1]) // 255
def build_lut(histo, step):
# Compute the cumulative sum, shifting by step // 2
# and then normalization by step.
lut = (tf.cumsum(histo) + (step // 2)) // step
# Shift lut, prepending with 0.
lut = tf.concat([[0], lut[:-1]], 0)
# Clip the counts to be in range. This is done
# in the C code for image.point.
return tf.clip_by_value(lut, 0, 255)
# If step is zero, return the original image. Otherwise, build
# lut from the full histogram and step and then index from it.
result = tf.cond(
tf.equal(step, 0), lambda: im,
lambda: tf.gather(build_lut(histo, step), im))
return tf.cast(result, tf.uint8)
# Assumes RGB for now. Scales each channel independently
# and then stacks the result.
s1 = scale_channel(image, 0)
s2 = scale_channel(image, 1)
s3 = scale_channel(image, 2)
image = tf.stack([s1, s2, s3], 2)
return image
def invert(image):
"""Inverts the image pixels."""
image = tf.convert_to_tensor(image)
return 255 - image
NAME_TO_FUNC = {
'AutoContrast': autocontrast,
'Equalize': equalize,
'Invert': invert,
'Rotate': rotate,
'Posterize': posterize,
'PosterizeIncreasing': posterize,
'Solarize': solarize,
'SolarizeIncreasing': solarize,
'SolarizeAdd': solarize_add,
'Color': color,
'ColorIncreasing': color,
'Contrast': contrast,
'ContrastIncreasing': contrast,
'Brightness': brightness,
'BrightnessIncreasing': brightness,
'Sharpness': sharpness,
'SharpnessIncreasing': sharpness,
'ShearX': tfa_image.shear_x,
'ShearY': tfa_image.shear_y,
'TranslateX': translate_x,
'TranslateY': translate_y,
'Cutout': tfa_image.random_cutout,
'Hue': tf.image.adjust_hue,
}
def _randomly_negate_tensor(tensor):
"""With 50% prob turn the tensor negative."""
should_flip = tf.cast(tf.floor(tf.random.uniform([]) + 0.5), tf.bool)
final_tensor = tf.cond(should_flip, lambda: -tensor, lambda: tensor)
return final_tensor
def _rotate_level_to_arg(level):
level = (level / _MAX_LEVEL) * 30.
level = _randomly_negate_tensor(level)
return (level,)
def _shrink_level_to_arg(level):
"""Converts level to ratio by which we shrink the image content."""
if level == 0:
return (1.0,) # if level is zero, do not shrink the image
# Maximum shrinking ratio is 2.9.
level = 2. / (_MAX_LEVEL / level) + 0.9
return (level,)
def _enhance_level_to_arg(level):
return ((level / _MAX_LEVEL) * 1.8 + 0.1,)
def _enhance_increasing_level_to_arg(level):
level = (level / _MAX_LEVEL) * .9
level = 1.0 + _randomly_negate_tensor(level)
return (level,)
def _shear_level_to_arg(level):
level = (level / _MAX_LEVEL) * 0.3
# Flip level to negative with 50% chance.
level = _randomly_negate_tensor(level)
return (level,)
def _translate_level_to_arg(level, translate_const):
level = level / _MAX_LEVEL * translate_const
# Flip level to negative with 50% chance.
level = _randomly_negate_tensor(level)
return (level,)
def _posterize_level_to_arg(level):
return (tf.cast(level / _MAX_LEVEL * 4, tf.uint8),)
def _posterize_increase_level_to_arg(level):
return (4 - _posterize_level_to_arg(level)[0],)
def _solarize_level_to_arg(level):
return (tf.cast(level / _MAX_LEVEL * 256, tf.uint8),)
def _solarize_increase_level_to_arg(level):
return (256 - _solarize_level_to_arg(level)[0],)
def _solarize_add_level_to_arg(level):
return (tf.cast(level / _MAX_LEVEL * 110, tf.int64),)
def _cutout_arg(level, cutout_size):
pad_size = tf.cast(level / _MAX_LEVEL * cutout_size, tf.int32)
return (2 * pad_size, 2 * pad_size)
def level_to_arg(hparams):
return {
'AutoContrast':
lambda level: (),
'Equalize':
lambda level: (),
'Invert':
lambda level: (),
'Rotate':
_rotate_level_to_arg,
'Posterize':
_posterize_level_to_arg,
'PosterizeIncreasing':
_posterize_increase_level_to_arg,
'Solarize':
_solarize_level_to_arg,
'SolarizeIncreasing':
_solarize_increase_level_to_arg,
'SolarizeAdd':
_solarize_add_level_to_arg,
'Color':
_enhance_level_to_arg,
'ColorIncreasing':
_enhance_increasing_level_to_arg,
'Contrast':
_enhance_level_to_arg,
'ContrastIncreasing':
_enhance_increasing_level_to_arg,
'Brightness':
_enhance_level_to_arg,
'BrightnessIncreasing':
_enhance_increasing_level_to_arg,
'Sharpness':
_enhance_level_to_arg,
'SharpnessIncreasing':
_enhance_increasing_level_to_arg,
'ShearX':
_shear_level_to_arg,
'ShearY':
_shear_level_to_arg,
# pylint:disable=g-long-lambda
'Cutout':
lambda level: _cutout_arg(level, hparams['cutout_const']),
# pylint:disable=g-long-lambda
'TranslateX':
lambda level: _translate_level_to_arg(level, hparams['translate_const'
]),
'TranslateY':
lambda level: _translate_level_to_arg(level, hparams['translate_const'
]),
'Hue':
lambda level: ((level / _MAX_LEVEL) * 0.25,),
# pylint:enable=g-long-lambda
}
def _parse_policy_info(name, prob, level, replace_value, augmentation_hparams):
"""Return the function that corresponds to `name` and update `level` param."""
func = NAME_TO_FUNC[name]
args = level_to_arg(augmentation_hparams)[name](level)
# Add in replace arg if it is required for the function that is being called.
# pytype:disable=wrong-arg-types
if 'replace' in inspect.signature(func).parameters.keys(): # pylint: disable=deprecated-method
args = tuple(list(args) + [replace_value])
# pytype:enable=wrong-arg-types
return (func, prob, args)
def _apply_func_with_prob(func, image, args, prob):
"""Apply `func` to image w/ `args` as input with probability `prob`."""
assert isinstance(args, tuple)
# Apply the function with probability `prob`.
should_apply_op = tf.cast(
tf.floor(tf.random.uniform([], dtype=tf.float32) + prob), tf.bool)
augmented_image = tf.cond(should_apply_op, lambda: func(image, *args),
lambda: image)
return augmented_image
def select_and_apply_random_policy(policies, image):
"""Select a random policy from `policies` and apply it to `image`."""
policy_to_select = tf.random.uniform([], maxval=len(policies), dtype=tf.int32)
# Note that using tf.case instead of tf.conds would result in significantly
# larger graphs and would even break export for some larger policies.
for (i, policy) in enumerate(policies):
image = tf.cond(
tf.equal(i, policy_to_select),
lambda selected_policy=policy: selected_policy(image),
lambda: image)
return image
def build_and_apply_nas_policy(policies, image, augmentation_hparams):
"""Build a policy from the given policies passed in and apply to image.
Args:
policies: list of lists of tuples in the form `(func, prob, level)`, `func`
is a string name of the augmentation function, `prob` is the probability
of applying the `func` operation, `level` is the input argument for
`func`.
image: tf.Tensor that the resulting policy will be applied to.
augmentation_hparams: Hparams associated with the NAS learned policy.
Returns:
A version of image that now has data augmentation applied to it based on
the `policies` passed into the function.
"""
replace_value = [128, 128, 128]
# func is the string name of the augmentation function, prob is the
# probability of applying the operation and level is the parameter associated
# with the tf op.
# tf_policies are functions that take in an image and return an augmented
# image.
tf_policies = []
for policy in policies:
tf_policy = []
# Link string name to the correct python function and make sure the correct
# argument is passed into that function.
for policy_info in policy:
policy_info = list(policy_info) + [replace_value, augmentation_hparams]
tf_policy.append(_parse_policy_info(*policy_info))
# Now build the tf policy that will apply the augmentation procedure
# on image.
def make_final_policy(tf_policy_):
def final_policy(image_):
for func, prob, args in tf_policy_:
image_ = _apply_func_with_prob(func, image_, args, prob)
return image_
return final_policy
tf_policies.append(make_final_policy(tf_policy))
augmented_image = select_and_apply_random_policy(tf_policies, image)
return augmented_image
def distort_image_with_autoaugment(image, augmentation_name):
"""Applies the AutoAugment policy to `image`.
AutoAugment is from the paper: https://arxiv.org/abs/1805.09501.
Args:
image: `Tensor` of shape [height, width, 3] representing an image.
augmentation_name: The name of the AutoAugment policy to use. The available
options are `v0` and `test`. `v0` is the policy used for all of the
results in the paper and was found to achieve the best results on the COCO
dataset. `v1`, `v2` and `v3` are additional good policies found on the
COCO dataset that have slight variation in what operations were used
during the search procedure along with how many operations are applied in
parallel to a single image (2 vs 3).
Returns:
The augmented version of `image`.
"""
available_policies = {'v0': policy_v0, 'test': policy_vtest}
if augmentation_name not in available_policies:
raise ValueError('Invalid augmentation_name: {}'.format(augmentation_name))
policy = available_policies[augmentation_name]()
# Hparams that will be used for AutoAugment.
augmentation_hparams = dict(cutout_const=100, translate_const=250)
return build_and_apply_nas_policy(policy, image, augmentation_hparams)
# Cutout is implemented separately.
_RAND_TRANSFORMS = [
'AutoContrast',
'Equalize',
'Invert',
'Rotate',
'Posterize',
'Solarize',
'Color',
'Contrast',
'Brightness',
'Sharpness',
'ShearX',
'ShearY',
'TranslateX',
'TranslateY',
'SolarizeAdd',
'Hue',
]
# Cutout is implemented separately.
_RAND_INCREASING_TRANSFORMS = [
'AutoContrast',
'Equalize',
'Invert',
'Rotate',
'PosterizeIncreasing',
'SolarizeIncreasing',
'SolarizeAdd',
'ColorIncreasing',
'ContrastIncreasing',
'BrightnessIncreasing',
'SharpnessIncreasing',
'ShearX',
'ShearY',
'TranslateX',
'TranslateY',
'Hue',
]
# These augmentations are not suitable for detection task.
_NON_COLOR_DISTORTION_OPS = [
'Rotate',
'ShearX',
'ShearY',
'TranslateX',
'TranslateY',
]
def distort_image_with_randaugment(image,
num_layers,
magnitude,
mag_std,
inc,
prob,
color_only=False):
"""Applies the RandAugment policy to `image`.
RandAugment is from the paper https://arxiv.org/abs/1909.13719,
Args:
image: `Tensor` of shape [height, width, 3] representing an image. The image
should have uint8 type in [0, 255].
num_layers: Integer, the number of augmentation transformations to apply
sequentially to an image. Represented as (N) in the paper. Usually best
values will be in the range [1, 3].
magnitude: Integer, shared magnitude across all augmentation operations.
Represented as (M) in the paper. Usually best values are in the range [5,
30].
mag_std: Randomness of magnitude. The magnitude will be sampled from a
normal distribution on the fly.
inc: Whether to select aug that increases as magnitude increases.
prob: Probability of any aug being applied.
color_only: Whether only apply operations that distort color and do not
change spatial layouts.
Returns:
The augmented version of `image`.
"""
replace_value = [128] * 3
augmentation_hparams = dict(cutout_const=40, translate_const=100)
available_ops = _RAND_INCREASING_TRANSFORMS if inc else _RAND_TRANSFORMS
if color_only:
available_ops = list(
filter(lambda op: op not in _NON_COLOR_DISTORTION_OPS, available_ops))
for layer_num in range(num_layers):
op_to_select = tf.random.uniform([],
maxval=len(available_ops),
dtype=tf.int32)
random_magnitude = tf.clip_by_value(
tf.random.normal([], magnitude, mag_std), 0., _MAX_LEVEL)
with tf.name_scope('randaug_layer_{}'.format(layer_num)):
for (i, op_name) in enumerate(available_ops):
func, _, args = _parse_policy_info(op_name, prob, random_magnitude,
replace_value, augmentation_hparams)
image = tf.cond(
tf.equal(i, op_to_select),
# pylint:disable=g-long-lambda
lambda s_func=func, s_args=args: _apply_func_with_prob(
s_func, image, s_args, prob),
# pylint:enable=g-long-lambda
lambda: image)
return image
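# Illustrative usage (not part of the original file): apply two RandAugment
# layers around magnitude 5, each with probability 0.5, restricted to
# color-only ops so that boxes/masks of a detection task remain valid:
#
#   image = tf.zeros([640, 640, 3], dtype=tf.uint8)  # placeholder uint8 image
#   augmented = distort_image_with_randaugment(
#       image, num_layers=2, magnitude=5., mag_std=0.5, inc=True, prob=0.5,
#       color_only=True)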
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Input data reader.
Creates a tf.data.Dataset object from multiple input SSTables or TFRecords and
uses a provided data parser function to decode the serialized tf.Example protos
and optionally run data augmentation.
"""
import os
from typing import Any, Callable, List, Optional, Sequence, Union
import gin
from six.moves import map
import tensorflow as tf
from official.common import dataset_fn
from research.object_detection.utils import label_map_util
from official.core import config_definitions as cfg
from official.projects.unified_detector.data_loaders import universal_detection_parser # pylint: disable=unused-import
FuncType = Callable[..., Any]
@gin.configurable(denylist=['is_training'])
class InputFn(object):
"""Input data reader class.
Creates a tf.data.Dataset object from multiple datasets (optionally performs
weighted sampling between different datasets), parses the tf.Example message
using `parser_fn`. The datasets can either be stored in SSTable or TfRecord.
"""
def __init__(self,
is_training: bool,
batch_size: Optional[int] = None,
data_root: str = '',
input_paths: List[str] = gin.REQUIRED,
dataset_type: str = 'tfrecord',
use_sampling: bool = False,
sampling_weights: Optional[Sequence[Union[int, float]]] = None,
cycle_length: Optional[int] = 64,
shuffle_buffer_size: Optional[int] = 512,
parser_fn: Optional[FuncType] = None,
parser_num_parallel_calls: Optional[int] = 64,
max_intra_op_parallelism: Optional[int] = None,
label_map_proto_path: Optional[str] = None,
input_filter_fns: Optional[List[FuncType]] = None,
input_training_filter_fns: Optional[Sequence[FuncType]] = None,
dense_to_ragged_batch: bool = False,
data_validator_fn: Optional[Callable[[Sequence[str]],
None]] = None):
"""Input reader constructor.
Args:
is_training: Boolean indicating TRAIN or EVAL.
batch_size: Input data batch size. Ignored if batch size is passed through
params. In that case, this can be None.
data_root: All the relative input paths are based on this location.
input_paths: Input file patterns.
dataset_type: Can be 'sstable' or 'tfrecord'.
use_sampling: Whether to perform weighted sampling between different
datasets.
sampling_weights: Unnormalized sampling weights. Its length should be
equal to the length of `input_paths`.
cycle_length: The number of input Datasets to interleave from in parallel.
If set to None tf.data experimental autotuning is used.
shuffle_buffer_size: The random shuffle buffer size.
parser_fn: The function to run decoding and data augmentation. The
function takes `is_training` as an input, which is passed from here.
parser_num_parallel_calls: The number of parallel calls for `parser_fn`.
The number of CPU cores is the suggested value. If set to None tf.data
experimental autotuning is used.
max_intra_op_parallelism: if set limits the max intra op parallelism of
functions run on slices of the input.
label_map_proto_path: Path to a StringIntLabelMap which will be used to
decode the input data.
input_filter_fns: A list of functions on the dataset points which return
true for valid data.
input_training_filter_fns: A list of functions on the dataset points which
return true for valid data, used only for training.
dense_to_ragged_batch: Whether to use ragged batching for MPNN format.
data_validator_fn: If not None, used to validate the data specified by
input_paths.
Raises:
ValueError for invalid input_paths.
"""
self._is_training = is_training
if data_root:
# If an input path is absolute this does not change it.
input_paths = [os.path.join(data_root, value) for value in input_paths]
self._input_paths = input_paths
# Disables datasets sampling during eval.
self._batch_size = batch_size
if is_training:
self._use_sampling = use_sampling
else:
self._use_sampling = False
self._sampling_weights = sampling_weights
self._cycle_length = (cycle_length if cycle_length else tf.data.AUTOTUNE)
self._shuffle_buffer_size = shuffle_buffer_size
self._parser_num_parallel_calls = (
parser_num_parallel_calls
if parser_num_parallel_calls else tf.data.AUTOTUNE)
self._max_intra_op_parallelism = max_intra_op_parallelism
self._label_map_proto_path = label_map_proto_path
if label_map_proto_path:
name_to_id = label_map_util.get_label_map_dict(label_map_proto_path)
self._lookup_str_keys = list(name_to_id.keys())
self._lookup_int_values = list(name_to_id.values())
self._parser_fn = parser_fn
self._input_filter_fns = input_filter_fns or []
if is_training and input_training_filter_fns:
self._input_filter_fns.extend(input_training_filter_fns)
self._dataset_type = dataset_type
self._dense_to_ragged_batch = dense_to_ragged_batch
if data_validator_fn is not None:
data_validator_fn(self._input_paths)
@property
def batch_size(self):
return self._batch_size
def __call__(
self,
params: cfg.DataConfig,
input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
"""Read and parse input datasets, return a tf.data.Dataset object."""
# TPUEstimator passes the batch size through params.
if params is not None and 'batch_size' in params:
batch_size = params['batch_size']
else:
batch_size = self._batch_size
per_replica_batch_size = input_context.get_per_replica_batch_size(
batch_size) if input_context else batch_size
with tf.name_scope('input_reader'):
dataset = self._build_dataset_from_records()
dataset_parser_fn = self._build_dataset_parser_fn()
dataset = dataset.map(
dataset_parser_fn, num_parallel_calls=self._parser_num_parallel_calls)
for filter_fn in self._input_filter_fns:
dataset = dataset.filter(filter_fn)
if self._dense_to_ragged_batch:
dataset = dataset.apply(
tf.data.experimental.dense_to_ragged_batch(
batch_size=per_replica_batch_size, drop_remainder=True))
else:
dataset = dataset.batch(per_replica_batch_size, drop_remainder=True)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
return dataset
def _fetch_dataset(self, filename: str) -> tf.data.Dataset:
"""Fetch dataset depending on type.
Args:
filename: Location of dataset.
Returns:
Tf Dataset.
"""
data_cls = dataset_fn.pick_dataset_fn(self._dataset_type)
data = data_cls([filename])
return data
def _build_dataset_parser_fn(self) -> Callable[..., tf.Tensor]:
"""Depending on label_map and storage type, build a parser_fn."""
# Parse the fetched records to input tensors for model function.
if self._label_map_proto_path:
lookup_initializer = tf.lookup.KeyValueTensorInitializer(
keys=tf.constant(self._lookup_str_keys, dtype=tf.string),
values=tf.constant(self._lookup_int_values, dtype=tf.int32))
name_to_id_table = tf.lookup.StaticHashTable(
initializer=lookup_initializer, default_value=0)
parser_fn = self._parser_fn(
is_training=self._is_training, label_lookup_table=name_to_id_table)
else:
parser_fn = self._parser_fn(is_training=self._is_training)
return parser_fn
def _build_dataset_from_records(self) -> tf.data.Dataset:
"""Build a tf.data.Dataset object from input SSTables.
If the input data come from multiple sources, the user-defined sampling
weights are used to perform weighted sampling. For example, if the sampling
weights are [1., 2.], the second dataset will be sampled twice as often as the
first one.
Returns:
Dataset built from SSTables.
Raises:
ValueError for inability to find SSTable files.
"""
all_file_patterns = []
if self._use_sampling:
for file_pattern in self._input_paths:
all_file_patterns.append([file_pattern])
# Normalize sampling probabilities.
total_weight = sum(self._sampling_weights)
sampling_probabilities = [
float(w) / total_weight for w in self._sampling_weights
]
else:
all_file_patterns.append(self._input_paths)
datasets = []
for file_pattern in all_file_patterns:
filenames = sum(list(map(tf.io.gfile.glob, file_pattern)), [])
if not filenames:
raise ValueError(
f'Error trying to read input files for file pattern {file_pattern}')
# Create a dataset of filenames and shuffle the files. In each epoch,
# the file order is shuffled again. This may help if
# per_host_input_for_training = false on TPU.
dataset = tf.data.Dataset.list_files(
file_pattern, shuffle=self._is_training)
if self._is_training:
dataset = dataset.repeat()
if self._max_intra_op_parallelism:
# Disable intra-op parallelism to optimize for throughput instead of
# latency.
options = tf.data.Options()
options.experimental_threading.max_intra_op_parallelism = 1
dataset = dataset.with_options(options)
dataset = dataset.interleave(
self._fetch_dataset,
cycle_length=self._cycle_length,
num_parallel_calls=self._cycle_length,
deterministic=(not self._is_training))
if self._is_training:
dataset = dataset.shuffle(self._shuffle_buffer_size)
datasets.append(dataset)
if self._use_sampling:
assert len(datasets) == len(sampling_probabilities)
dataset = tf.data.experimental.sample_from_datasets(
datasets, sampling_probabilities)
else:
dataset = datasets[0]
return dataset
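# Illustrative usage (not part of the original file): with gin supplying
# `input_paths`, `parser_fn` and `batch_size` (see the unified_detector gin
# configs), a training dataset can be built roughly as
#
#   input_fn = InputFn(is_training=True)
#   dataset = input_fn(params=None)
#
# Weighted mixing of multiple sources is enabled by passing
# `use_sampling=True` together with per-dataset `sampling_weights`.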
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tensorflow Example proto decoder for GOCR."""
from typing import List, Optional, Sequence, Tuple, Union
import tensorflow as tf
from official.projects.unified_detector.utils.typing import TensorDict
from official.vision.dataloaders import decoder
class TfExampleDecoder(decoder.Decoder):
"""Tensorflow Example proto decoder."""
def __init__(self,
use_instance_mask: bool = False,
additional_class_names: Optional[Sequence[str]] = None,
additional_regression_names: Optional[Sequence[str]] = None,
num_additional_channels: int = 0):
"""Constructor.
keys_to_features is a dictionary mapping the names of the tf.Example
fields to tf features, possibly with defaults.
Uses fixed length for scalars and variable length for vectors.
Args:
use_instance_mask: if False, prevents decoding of the instance mask, which
can take a lot of resources.
additional_class_names: If not None, a list of additional class names. For
each additional class name n, the feature named image/object/${n} is
expected to be an int vector of length one and is mapped to the tensor dict
key groundtruth_${n}.
additional_regression_names: If not None, a list of additional regression
output names. For each additional regression name n, the feature named
image/object/${n} is expected to be a float vector and is mapped to the
tensor dict key groundtruth_${n}.
num_additional_channels: The number of additional channels of information
present in the tf.Example proto.
"""
self._num_additional_channels = num_additional_channels
self._use_instance_mask = use_instance_mask
self.keys_to_features = {}
# Map names in the final tensor dict (output of `self.decode()`) to names in
# tf examples, e.g. 'groundtruth_text' -> 'image/object/text'
self.name_to_key = {}
if use_instance_mask:
self.keys_to_features.update({
'image/object/mask': tf.io.VarLenFeature(tf.string),
})
# Now we have lists of standard types.
# To add new features, just add entries here.
# The tuple elements are (example name, tensor name, default value).
# If the items_to_handlers part is already set up use None for
# the tensor name.
# There are other tensor names listed as None which we probably
# want to discuss and specify.
scalar_strings = [
('image/encoded', None, ''),
('image/format', None, 'jpg'),
('image/additional_channels/encoded', None, ''),
('image/additional_channels/format', None, 'png'),
('image/label_type', 'label_type', ''),
('image/key', 'key', ''),
('image/source_id', 'source_id', ''),
]
vector_strings = [
('image/attributes', None, ''),
('image/object/text', 'groundtruth_text', ''),
('image/object/encoded_text', 'groundtruth_encoded_text', ''),
('image/object/vertices', 'groundtruth_vertices', ''),
('image/object/object_type', None, ''),
('image/object/language', 'language', ''),
('image/object/reorderer_type', None, ''),
('image/label_map_path', 'label_map_path', '')
]
scalar_ints = [
('image/height', None, 1),
('image/width', None, 1),
('image/channels', None, 3),
]
vector_ints = [
('image/object/classes', 'groundtruth_classes', 0),
('image/object/frame_id', 'frame_id', 0),
('image/object/track_id', 'track_id', 0),
('image/object/content_type', 'groundtruth_content_type', 0),
]
if additional_class_names:
vector_ints += [('image/object/%s' % name, 'groundtruth_%s' % name, 0)
for name in additional_class_names]
# This one is not yet needed:
# scalar_floats = [
# ]
vector_floats = [
('image/object/weight', 'groundtruth_weight', 0),
('image/object/rbox_tl_x', None, 0),
('image/object/rbox_tl_y', None, 0),
('image/object/rbox_width', None, 0),
('image/object/rbox_height', None, 0),
('image/object/rbox_angle', None, 0),
('image/object/bbox/xmin', None, 0),
('image/object/bbox/xmax', None, 0),
('image/object/bbox/ymin', None, 0),
('image/object/bbox/ymax', None, 0),
]
if additional_regression_names:
vector_floats += [('image/object/%s' % name, 'groundtruth_%s' % name, 0)
for name in additional_regression_names]
self._init_scalar_features(scalar_strings, tf.string)
self._init_vector_features(vector_strings, tf.string)
self._init_scalar_features(scalar_ints, tf.int64)
self._init_vector_features(vector_ints, tf.int64)
self._init_vector_features(vector_floats, tf.float32)
def _init_scalar_features(
self,
feature_list: List[Tuple[str, Optional[str], Union[str, int, float]]],
ftype: tf.dtypes.DType) -> None:
for entry in feature_list:
self.keys_to_features[entry[0]] = tf.io.FixedLenFeature(
(), ftype, default_value=entry[2])
if entry[1] is not None:
self.name_to_key[entry[1]] = entry[0]
def _init_vector_features(
self,
feature_list: List[Tuple[str, Optional[str], Union[str, int, float]]],
ftype: tf.dtypes.DType) -> None:
for entry in feature_list:
self.keys_to_features[entry[0]] = tf.io.VarLenFeature(ftype)
if entry[1] is not None:
self.name_to_key[entry[1]] = entry[0]
def _decode_png_instance_masks(self, keys_to_tensors: TensorDict)-> tf.Tensor:
"""Decode PNG instance segmentation masks and stack into dense tensor.
The instance segmentation masks are reshaped to [num_instances, height,
width].
Args:
keys_to_tensors: A dictionary from keys to tensors.
Returns:
A 3-D float tensor of shape [num_instances, height, width] with values
in {0, 1}.
"""
def decode_png_mask(image_buffer):
image = tf.squeeze(
tf.image.decode_image(image_buffer, channels=1), axis=2)
image.set_shape([None, None])
image = tf.cast(tf.greater(image, 0), tf.float32)
return image
png_masks = keys_to_tensors['image/object/mask']
height = keys_to_tensors['image/height']
width = keys_to_tensors['image/width']
if isinstance(png_masks, tf.SparseTensor):
png_masks = tf.sparse.to_dense(png_masks, default_value='')
return tf.cond(
tf.greater(tf.size(png_masks), 0),
lambda: tf.map_fn(decode_png_mask, png_masks, dtype=tf.float32),
lambda: tf.zeros(tf.cast(tf.stack([0, height, width]), tf.int32)))
def _decode_image(self,
parsed_tensors: TensorDict,
channel: int = 3) -> TensorDict:
"""Decodes the image and set its shape (H, W are dynamic; C is fixed)."""
image = tf.io.decode_image(parsed_tensors['image/encoded'],
channels=channel)
image.set_shape([None, None, channel])
return {'image': image}
def _decode_additional_channels(self,
parsed_tensors: TensorDict,
channel: int = 3) -> TensorDict:
"""Decodes the additional channels and set its static shape."""
channels = tf.io.decode_image(
parsed_tensors['image/additional_channels/encoded'], channels=channel)
channels.set_shape([None, None, channel])
return {'additional_channels': channels}
def _decode_boxes(self, parsed_tensors: TensorDict) -> TensorDict:
"""Concat box coordinates in the format of [ymin, xmin, ymax, xmax]."""
xmin = parsed_tensors['image/object/bbox/xmin']
xmax = parsed_tensors['image/object/bbox/xmax']
ymin = parsed_tensors['image/object/bbox/ymin']
ymax = parsed_tensors['image/object/bbox/ymax']
return {
'groundtruth_aligned_boxes': tf.stack([ymin, xmin, ymax, xmax], axis=-1)
}
def _decode_rboxes(self, parsed_tensors: TensorDict) -> TensorDict:
"""Concat rbox coordinates: [left, top, box_width, box_height, angle]."""
top_left_x = parsed_tensors['image/object/rbox_tl_x']
top_left_y = parsed_tensors['image/object/rbox_tl_y']
width = parsed_tensors['image/object/rbox_width']
height = parsed_tensors['image/object/rbox_height']
angle = parsed_tensors['image/object/rbox_angle']
return {
'groundtruth_boxes':
tf.stack([top_left_x, top_left_y, width, height, angle], axis=-1)
}
def _decode_masks(self, parsed_tensors: TensorDict) -> TensorDict:
"""Decode a set of PNG masks to the tf.float32 tensors."""
def _decode_png_mask(png_bytes):
mask = tf.squeeze(
tf.io.decode_png(png_bytes, channels=1, dtype=tf.uint8), axis=-1)
mask = tf.cast(mask, dtype=tf.float32)
mask.set_shape([None, None])
return mask
height = parsed_tensors['image/height']
width = parsed_tensors['image/width']
masks = parsed_tensors['image/object/mask']
masks = tf.cond(
pred=tf.greater(tf.size(input=masks), 0),
true_fn=lambda: tf.map_fn(_decode_png_mask, masks, dtype=tf.float32),
false_fn=lambda: tf.zeros([0, height, width], dtype=tf.float32))
return {'groundtruth_instance_masks': masks}
def decode(self, tf_example_string_tensor: tf.string):
"""Decodes serialized tensorflow example and returns a tensor dictionary.
Args:
tf_example_string_tensor: A string tensor holding a serialized tensorflow
example proto.
Returns:
      A dictionary containing a subset of the following, depending on the
      inputs:
      image: A uint8 tensor of shape [height, width, 3] containing the image.
      source_id: A string tensor containing the image fingerprint.
      key: A string tensor containing the unique sha256 hash key.
      label_type: Either `full` or `partial`. `full` means all the text is
        fully labeled, `partial` otherwise. Currently, this is used by the E2E
        model. If an input image is fully labeled, we update the weights of
        both the detector and the recognizer. Otherwise, only the recognizer
        part of the model is trained.
      groundtruth_text: A string tensor list, the original transcriptions.
      groundtruth_encoded_text: A string tensor list, the class ids for the
        atoms in the text, after applying the reordering algorithm, in string
        form. For example "90,71,85,69,86,85,93,90,71,91,1,71,85,93,90,71".
        This depends on the class label map provided to the conversion
        program. These are 0 based, with -1 for OOV symbols.
      groundtruth_classes: An int32 tensor of shape [num_boxes] containing the
        class ids. Note these are 1 based; 0 is reserved for the background
        class.
      groundtruth_content_type: An int32 tensor of shape [num_boxes] containing
        the content type. Values correspond to PageLayoutEntity::ContentType.
      groundtruth_weight: An int32 tensor of shape [num_boxes], either 0 or 1.
        If a region has weight 0, it will be ignored when computing the
        losses.
      groundtruth_boxes: A float tensor of shape [num_boxes, 5] containing the
        groundtruth rotated rectangles. Each row is in [left, top, box_width,
        box_height, angle] order; absolute coordinates are used.
      groundtruth_aligned_boxes: A float tensor of shape [num_boxes, 4]
        containing the groundtruth axis-aligned rectangles. Each row is in
        [ymin, xmin, ymax, xmax] order. Currently, this is used to store
        groundtruth symbol boxes.
      groundtruth_vertices: A string tensor list containing encoded normalized
        box or polygon coordinates, e.g. `x1,y1,x2,y2,x3,y3,x4,y4`.
      groundtruth_instance_masks: A float tensor of shape [num_boxes, height,
        width] containing binarized, image-sized instance segmentation masks:
        `1.0` for positive regions, `0.0` otherwise. None if not present in
        the tf.Example.
      frame_id: An int32 tensor of shape [num_boxes], either `0` or `1`. `0`
        means the object comes from the first image, `1` the second.
      track_id: An int32 tensor of shape [num_boxes], whose value indicates
        identity across frame indices.
      additional_channels: A uint8 tensor of shape [H, W, C] representing
        additional features.
"""
parsed_tensors = tf.io.parse_single_example(
serialized=tf_example_string_tensor, features=self.keys_to_features)
for k in parsed_tensors:
if isinstance(parsed_tensors[k], tf.SparseTensor):
if parsed_tensors[k].dtype == tf.string:
parsed_tensors[k] = tf.sparse.to_dense(
parsed_tensors[k], default_value='')
else:
parsed_tensors[k] = tf.sparse.to_dense(
parsed_tensors[k], default_value=0)
decoded_tensors = {}
decoded_tensors.update(self._decode_image(parsed_tensors))
decoded_tensors.update(self._decode_rboxes(parsed_tensors))
decoded_tensors.update(self._decode_boxes(parsed_tensors))
if self._use_instance_mask:
decoded_tensors[
'groundtruth_instance_masks'] = self._decode_png_instance_masks(
parsed_tensors)
if self._num_additional_channels:
decoded_tensors.update(self._decode_additional_channels(
parsed_tensors, self._num_additional_channels))
# other attributes:
for key in self.name_to_key:
if key not in decoded_tensors:
decoded_tensors[key] = parsed_tensors[self.name_to_key[key]]
if 'groundtruth_instance_masks' not in decoded_tensors:
decoded_tensors['groundtruth_instance_masks'] = None
return decoded_tensors
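# Illustrative usage sketch (not part of the original file). The constructor
# arguments mirror how `UniDetectorParserFn` instantiates this decoder;
# `serialized_example` is a hypothetical scalar string tensor holding one
# serialized tf.Example.
#
#   decoder = TfExampleDecoder(
#       num_additional_channels=3, additional_class_names=['parent'])
#   tensors = decoder.decode(serialized_example)
#   image = tensors['image']                  # uint8, [H, W, 3]
#   rboxes = tensors['groundtruth_boxes']     # float32, [num_boxes, 5]
#   classes = tensors['groundtruth_classes']  # int64, [num_boxes]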
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data parser for universal detector."""
import enum
import functools
from typing import Any, Tuple
import gin
import tensorflow as tf
from official.projects.unified_detector.data_loaders import autoaugment
from official.projects.unified_detector.data_loaders import tf_example_decoder
from official.projects.unified_detector.utils import utilities
from official.projects.unified_detector.utils.typing import NestedTensorDict
from official.projects.unified_detector.utils.typing import TensorDict
@gin.constants_from_enum
class DetectionClass(enum.IntEnum):
"""As in `PageLayoutEntity.EntityType`."""
WORD = 0
LINE = 2
PARAGRAPH = 3
BLOCK = 4
NOT_ANNOTATED_ID = 8
def _erase(mask: tf.Tensor,
feature: tf.Tensor,
min_val: float = 0.,
max_val: float = 256.) -> tf.Tensor:
"""Erase the feature maps with a mask.
Erase feature maps with a mask and replace the erased area with uniform random
noise. The mask can have different size from the feature maps.
Args:
    mask: An (h, w) binary mask selecting the pixels to erase. Value 1 marks
      pixels to erase.
feature: the (H, W, C) feature maps to erase from.
min_val: The minimum value of random noise.
max_val: The maximum value of random noise.
Returns:
The (H, W, C) feature maps, with pixels in mask replaced with noises. It's
equal to mask * noise + (1 - mask) * feature.
"""
h, w, c = utilities.resolve_shape(feature)
resized_mask = tf.image.resize(
tf.tile(tf.expand_dims(tf.cast(mask, tf.float32), -1), (1, 1, c)), (h, w))
erased = tf.where(
condition=(resized_mask > 0.5),
x=tf.cast(tf.random.uniform((h, w, c), min_val, max_val), feature.dtype),
y=feature)
return erased
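# A minimal, hypothetical example of `_erase` (not in the original file): a
# 2x2 mask erases the left half of a 4x4x3 feature map with uniform noise.
#
#   feature = tf.zeros((4, 4, 3))
#   mask = tf.constant([[1, 0], [1, 0]])  # 1 == erase
#   erased = _erase(mask, feature, min_val=0., max_val=256.)
#   # The left 4x2x3 block is now random noise in [0, 256); the rest stays 0.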
@gin.configurable(denylist=['is_training'])
class UniDetectorParserFn(object):
"""Data parser for universal detector."""
def __init__(
self,
is_training: bool,
output_dimension: int = 1025,
mask_dimension: int = -1,
max_num_instance: int = 128,
rot90_probability: float = 0.5,
use_color_distortion: bool = True,
randaug_mag: float = 5.,
randaug_std: float = 0.5,
randaug_layer: int = 2,
randaug_prob: float = 0.5,
use_cropping: bool = True,
crop_min_scale: float = 0.5,
crop_max_scale: float = 1.5,
crop_min_aspect: float = 4 / 5,
crop_max_aspect: float = 5 / 4,
is_shape_defined: bool = True,
use_tpu: bool = True,
detection_unit: DetectionClass = DetectionClass.LINE,
):
"""Constructor.
Args:
is_training: bool indicating TRAIN or EVAL.
output_dimension: The size of input images.
      mask_dimension: The size of the output mask. If negative or zero, it will
        be set to `output_dimension`.
max_num_instance: The maximum number of instances to output. If it's
negative, padding or truncating will not be performed.
rot90_probability: The probability of rotating multiples of 90 degrees.
use_color_distortion: Whether to apply color distortions to images (via
autoaugment).
      randaug_mag: (autoaugment parameter) Color distortion magnitude. Note
        that this value should be set conservatively, as some color distortions
        (e.g. posterize) can easily make text illegible.
randaug_std: (autoaugment parameter) Randomness in color distortion
magnitude.
randaug_layer: (autoaugment parameter) Number of color distortion
operations.
      randaug_prob: (autoaugment parameter) Probability of applying each
distortion operation.
use_cropping: Bool, whether to use random cropping and resizing in
training.
crop_min_scale: The minimum scale of a random crop.
crop_max_scale: The maximum scale of a random crop. If >1, it means the
images are downsampled.
crop_min_aspect: The minimum aspect ratio of a random crop.
crop_max_aspect: The maximum aspect ratio of a random crop.
is_shape_defined: Whether to define the static shapes for all features and
labels. This must be set to True in TPU training as it requires static
shapes for all tensors.
use_tpu: Whether the inputs are fed to a TPU device.
      detection_unit: The granularity (word, line, or paragraph) regarded as a
        detection entity. The instance masks will be at this level.
"""
if is_training and max_num_instance < 0:
raise ValueError('In TRAIN mode, padding/truncation is required.')
self._is_training = is_training
self._output_dimension = output_dimension
self._mask_dimension = (
mask_dimension if mask_dimension > 0 else output_dimension)
self._max_num_instance = max_num_instance
self._decoder = tf_example_decoder.TfExampleDecoder(
num_additional_channels=3, additional_class_names=['parent'])
self._use_color_distortion = use_color_distortion
self._rot90_probability = rot90_probability
self._randaug_mag = randaug_mag
self._randaug_std = randaug_std
self._randaug_layer = randaug_layer
self._randaug_prob = randaug_prob
self._use_cropping = use_cropping
self._crop_min_scale = crop_min_scale
self._crop_max_scale = crop_max_scale
self._crop_min_aspect = crop_min_aspect
self._crop_max_aspect = crop_max_aspect
self._is_shape_defined = is_shape_defined
self._use_tpu = use_tpu
self._detection_unit = detection_unit
def __call__(self, value: str) -> Tuple[TensorDict, NestedTensorDict]:
"""Parsing the data.
Args:
value: The serialized data sample.
Returns:
Two dicts for features and labels.
features:
'source_id': id of the sample; only in EVAL mode
'images': the normalized images, (output_dimension, output_dimension, 3)
labels:
See `_prepare_labels` for its content.
"""
data = self._decoder.decode(value)
features = {}
labels = {}
self._preprocess(data, features, labels)
self._rot90k(data, features, labels)
self._crop_and_resize(data, features, labels)
self._color_distortion_and_normalize(data, features, labels)
self._prepare_labels(data, features, labels)
self._define_shapes(features, labels)
return features, labels
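  # A minimal input-pipeline sketch (an assumption, not code from this repo):
  # `file_pattern` is a hypothetical glob of TFRecord files holding serialized
  # tf.Examples in the format expected by TfExampleDecoder.
  #
  #   parser = UniDetectorParserFn(is_training=True)
  #   ds = tf.data.TFRecordDataset(tf.io.gfile.glob(file_pattern))
  #   ds = ds.map(parser, num_parallel_calls=tf.data.AUTOTUNE).batch(8)
  #   features, labels = next(iter(ds))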
def _preprocess(self, data: TensorDict, features: TensorDict,
unused_labels: TensorDict):
"""All kinds of preprocessing of the decoded data dict."""
    # (1) Decode the entity_id_mask: an H*W*1 mask where each pixel equals
    # (1 + position) of the entity in the GT entity list. The IDs
    # (which can be larger than 255) are stored in the last two channels.
data['additional_channels'] = tf.cast(data['additional_channels'], tf.int32)
entity_id_mask = (
data['additional_channels'][:, :, -2:-1] * 256 +
data['additional_channels'][:, :, -1:])
data['entity_id_mask'] = entity_id_mask
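    # For illustration (hypothetical values): a pixel belonging to the entity
    # at index 299 in the GT list stores 300 (= 1 + 299), split across the two
    # channels as high byte 1 and low byte 44 (1 * 256 + 44 == 300); the
    # arithmetic above reassembles it into a single int32 id per pixel.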
# (2) Write image id. Used in evaluation.
if not self._use_tpu:
features['source_id'] = data['source_id']
# (3) Block mask: area without annotation
data['image'] = _erase(
data['additional_channels'][:, :, 0],
data['image'],
min_val=0.,
max_val=256.)
def _rot90k(self, data: TensorDict, unused_features: TensorDict,
unused_labels: TensorDict):
"""Rotate the image, gt_bboxes, masks by 90k degrees."""
if not self._is_training:
return
rotate_90_choice = tf.random.uniform([])
def _rotate():
"""Rotation.
These will be rotated:
image,
rbox,
entity_id_mask,
TODO(longshangbang): rotate vertices.
Returns:
The rotated tensors of the above fields.
"""
k = tf.random.uniform([], 1, 4, dtype=tf.int32)
h, w, _ = utilities.resolve_shape(data['image'])
# Image
rotated_img = tf.image.rot90(data['image'], k=k, name='image_rot90k')
# Box
rotate_box_op = functools.partial(
utilities.rotate_rboxes90,
rboxes=data['groundtruth_boxes'],
image_width=w,
image_height=h)
rotated_boxes = tf.switch_case(
k - 1, # Indices start with 1.
branch_fns=[
lambda: rotate_box_op(rotation_count=1),
lambda: rotate_box_op(rotation_count=2),
lambda: rotate_box_op(rotation_count=3)
])
# Mask
rotated_mask = tf.image.rot90(
data['entity_id_mask'], k=k, name='mask_rot90k')
return rotated_img, rotated_boxes, rotated_mask
# pylint: disable=g-long-lambda
(data['image'], data['groundtruth_boxes'],
data['entity_id_mask']) = tf.cond(
rotate_90_choice < self._rot90_probability, _rotate, lambda:
(data['image'], data['groundtruth_boxes'], data['entity_id_mask']))
# pylint: enable=g-long-lambda
def _crop_and_resize(self, data: TensorDict, unused_features: TensorDict,
unused_labels: TensorDict):
"""Perform random cropping and resizing."""
# TODO(longshangbang): resize & translate box as well
# TODO(longshangbang): resize & translate vertices as well
# Get cropping target.
h, w = utilities.resolve_shape(data['image'])[:2]
left, top, crop_w, crop_h, pad_w, pad_h = self._get_crop_box(
tf.cast(h, tf.float32), tf.cast(w, tf.float32))
# Crop the image. (Pad the images if the crop box is larger than image.)
if self._is_training:
# padding left, top, right, bottom
pad_left = tf.random.uniform([], 0, pad_w + 1, dtype=tf.int32)
pad_top = tf.random.uniform([], 0, pad_h + 1, dtype=tf.int32)
else:
pad_left = 0
pad_top = 0
cropped_img = tf.image.crop_to_bounding_box(data['image'], top, left,
crop_h, crop_w)
padded_img = tf.pad(
cropped_img,
[[pad_top, pad_h - pad_top], [pad_left, pad_w - pad_left], [0, 0]],
constant_values=127)
# Resize images
data['resized_image'] = tf.image.resize(
padded_img, (self._output_dimension, self._output_dimension))
data['resized_image'] = tf.cast(data['resized_image'], tf.uint8)
# Crop the masks
cropped_masks = tf.image.crop_to_bounding_box(data['entity_id_mask'], top,
left, crop_h, crop_w)
padded_masks = tf.pad(
cropped_masks,
[[pad_top, pad_h - pad_top], [pad_left, pad_w - pad_left], [0, 0]])
# Resize masks
data['resized_masks'] = tf.image.resize(
padded_masks, (self._mask_dimension, self._mask_dimension),
method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
data['resized_masks'] = tf.squeeze(data['resized_masks'], -1)
def _get_crop_box(
self, h: tf.Tensor,
w: tf.Tensor) -> Tuple[Any, Any, tf.Tensor, tf.Tensor, Any, Any]:
"""Get the cropping box.
Args:
h: The height of the image to crop. Should be float type.
w: The width of the image to crop. Should be float type.
Returns:
A tuple representing (left, top, crop_w, crop_h, pad_w, pad_h).
Then in `self._crop_and_resize`, a crop will be extracted with bounding
box from top-left corner (left, top) and with size (crop_w, crop_h). This
crop will then be padded with (pad_w, pad_h) to square sizes.
      The outputs are also cast back to int32 type.
"""
if not self._is_training or not self._use_cropping:
# cast back to integers.
w = tf.cast(w, tf.int32)
h = tf.cast(h, tf.int32)
side = tf.maximum(w, h)
return 0, 0, w, h, side - w, side - h
# Get box size
scale = tf.random.uniform([], self._crop_min_scale, self._crop_max_scale)
max_edge = tf.maximum(w, h)
long_edge = max_edge * scale
sqrt_aspect_ratio = tf.math.sqrt(
tf.random.uniform([], self._crop_min_aspect, self._crop_max_aspect))
box_h = long_edge / sqrt_aspect_ratio
box_w = long_edge * sqrt_aspect_ratio
# Get box location
left = tf.random.uniform([], 0., tf.maximum(0., w - box_w))
top = tf.random.uniform([], 0., tf.maximum(0., h - box_h))
# Get crop & pad
crop_w = tf.minimum(box_w, w - left)
crop_h = tf.minimum(box_h, h - top)
pad_w = box_w - crop_w
pad_h = box_h - crop_h
return (tf.cast(left, tf.int32), tf.cast(top, tf.int32),
tf.cast(crop_w, tf.int32), tf.cast(crop_h, tf.int32),
tf.cast(pad_w, tf.int32), tf.cast(pad_h, tf.int32))
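  # Worked example for `_get_crop_box` (hypothetical numbers, training mode):
  # for a 1000x800 image (h=1000, w=800) with scale=0.6 and a unit aspect
  # ratio, long_edge = 600, so box_h = box_w = 600. If left=100 and top=350
  # are sampled, then crop_w = min(600, 800 - 100) = 600 and
  # crop_h = min(600, 1000 - 350) = 600, so pad_w = pad_h = 0; the 600x600
  # crop is later resized to (output_dimension, output_dimension).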
def _color_distortion_and_normalize(self, data: TensorDict,
features: TensorDict,
unused_labels: TensorDict):
"""Distort colors."""
if self._is_training and self._use_color_distortion:
data['resized_image'] = autoaugment.distort_image_with_randaugment(
data['resized_image'], self._randaug_layer, self._randaug_mag,
self._randaug_std, True, self._randaug_prob, True)
# Normalize
features['images'] = utilities.normalize_image_to_range(
data['resized_image'])
def _prepare_labels(self, data: TensorDict, features: TensorDict,
labels: TensorDict):
"""This function prepares the labels.
These following targets are added to labels['segmentation_output']:
'gt_word_score': A (h, w) float32 mask for textness score. 1 for word,
0 for bkg.
These following targets are added to labels['instance_labels']:
      'num_instance': A float scalar tensor for the total number of
        instances. It is bounded by the maximum number of instances allowed.
        It includes the special background instance, so it equals
        (1 + number of entities).
'masks': A (h, w) int32 mask for entity IDs. The value of each pixel is
the id of the entity it belongs to. A value of `0` means the bkg mask.
'classes': A (max_num,) int tensor indicating the classes of each
instance:
2 for background
1 for text entity
0 for non-object
'masks_sizes': A (max_num,) float tensor for the size of all masks.
'gt_weights': Whether it's difficult / does not have text annotation.
These following targets are added to labels['paragraph_labels']:
      'paragraph_ids': A (max_num,) integer tensor for the paragraph id. If
        `-1`, the text has no paragraph label.
'has_para_ids': A float scalar; 1.0 if the sample has paragraph labels.
Args:
data: The data dictionary.
features: The feature dict.
labels: The label dict.
"""
# Segmentation labels:
self._get_segmentation_labels(data, features, labels)
# Instance labels:
self._get_instance_labels(data, features, labels)
def _get_segmentation_labels(self, data: TensorDict,
unused_features: TensorDict,
labels: NestedTensorDict):
labels['segmentation_output'] = {
'gt_word_score': tf.cast((data['resized_masks'] > 0), tf.float32)
}
def _get_instance_labels(self, data: TensorDict, features: TensorDict,
labels: NestedTensorDict):
"""Generate the labels for text entity detection."""
labels['instance_labels'] = {}
# (1) Depending on `detection_unit`:
# Convert the word-id map to line-id map or use the word-id map directly
# Word entity ids start from 1 in the map, so pad a -1 at the beginning of
# the parent list to counter this offset.
padded_parent = tf.concat(
[tf.constant([-1]),
tf.cast(data['groundtruth_parent'], tf.int32)], 0)
if self._detection_unit == DetectionClass.WORD:
entity_id_mask = data['resized_masks']
elif self._detection_unit == DetectionClass.LINE:
# The pixel value is entity_id + 1, shape = [H, W]; 0 for background.
# correctness:
# 0s in data['resized_masks'] --> padded_parent[0] == -1
# i-th entity in plp.entities --> i+1 in data['resized_masks']
# --> padded_parent[i+1]
# --> data['groundtruth_parent'][i]
# --> the parent of i-th entity
entity_id_mask = tf.gather(padded_parent, data['resized_masks']) + 1
elif self._detection_unit == DetectionClass.PARAGRAPH:
# directly segmenting paragraphs; two hops here.
entity_id_mask = tf.gather(padded_parent, data['resized_masks']) + 1
entity_id_mask = tf.gather(padded_parent, entity_id_mask) + 1
else:
raise ValueError(f'No such detection unit: {self._detection_unit}')
data['entity_id_mask'] = entity_id_mask
# (2) Get individual masks for entities.
entity_selection_mask = tf.equal(data['groundtruth_classes'],
self._detection_unit)
num_all_entity = utilities.resolve_shape(data['groundtruth_classes'])[0]
# entity_ids is a 1-D tensor for IDs of all entities of a certain type.
entity_ids = tf.boolean_mask(
tf.range(num_all_entity, dtype=tf.int32), entity_selection_mask) # (N,)
# +1 to match the entity ids in entity_id_mask
entity_ids = tf.reshape(entity_ids, (-1, 1, 1)) + 1
individual_masks = tf.expand_dims(entity_id_mask, 0)
individual_masks = tf.equal(entity_ids, individual_masks) # (N, H, W), bool
# TODO(longshangbang): replace with real mask sizes computing.
# Currently, we use full-resolution masks for individual_masks. In order to
# compute mask sizes, we need to convert individual_masks to int/float type.
# This will cause OOM because the mask is too large.
masks_sizes = tf.cast(
tf.reduce_any(individual_masks, axis=[1, 2]), tf.float32)
# remove empty masks (usually caused by cropping)
non_empty_masks_ids = tf.not_equal(masks_sizes, 0)
valid_masks = tf.boolean_mask(individual_masks, non_empty_masks_ids)
valid_entity_ids = tf.boolean_mask(entity_ids, non_empty_masks_ids)[:, 0, 0]
# (3) Write num of instance
num_instance = tf.reduce_sum(tf.cast(non_empty_masks_ids, tf.float32))
num_instance_and_bkg = num_instance + 1
if self._max_num_instance >= 0:
num_instance_and_bkg = tf.minimum(num_instance_and_bkg,
self._max_num_instance)
labels['instance_labels']['num_instance'] = num_instance_and_bkg
# (4) Write instance masks
num_entity_int = tf.cast(num_instance, tf.int32)
max_num_entities = self._max_num_instance - 1 # Spare 1 for bkg.
pad_num = tf.maximum(max_num_entities - num_entity_int, 0)
padded_valid_masks = tf.pad(valid_masks, [[0, pad_num], [0, 0], [0, 0]])
# If there are more instances than allowed, randomly sample some.
    # `random_selection_mask` is a 0/1 array; at most
    # `self._max_num_instance - 1` entries are 1 (one slot is reserved for the
    # background); if not bounded, it is an array of all 1s.
if self._max_num_instance >= 0:
padded_size = num_entity_int + pad_num
random_selection = tf.random.uniform((padded_size,), dtype=tf.float32)
selected_indices = tf.math.top_k(random_selection, k=max_num_entities)[1]
random_selection_mask = tf.scatter_nd(
indices=tf.expand_dims(selected_indices, axis=-1),
updates=tf.ones((max_num_entities,), dtype=tf.bool),
shape=(padded_size,))
else:
random_selection_mask = tf.ones((num_entity_int,), dtype=tf.bool)
random_discard_mask = tf.logical_not(random_selection_mask)
kept_masks = tf.boolean_mask(padded_valid_masks, random_selection_mask)
erased_masks = tf.boolean_mask(padded_valid_masks, random_discard_mask)
erased_masks = tf.cast(tf.reduce_any(erased_masks, axis=0), tf.float32)
    # Erase text instances that are omitted.
features['images'] = _erase(erased_masks, features['images'], -1., 1.)
labels['segmentation_output']['gt_word_score'] *= 1. - erased_masks
kept_masks_and_bkg = tf.concat(
[
tf.math.logical_not(
tf.reduce_any(kept_masks, axis=0, keepdims=True)), # bkg
kept_masks,
],
0)
labels['instance_labels']['masks'] = tf.argmax(kept_masks_and_bkg, axis=0)
# (5) Write mask size
# TODO(longshangbang): replace with real masks sizes
masks_sizes = tf.cast(
tf.reduce_any(kept_masks_and_bkg, axis=[1, 2]), tf.float32)
labels['instance_labels']['masks_sizes'] = masks_sizes
# (6) Write classes.
classes = tf.ones((num_instance,), dtype=tf.int32)
classes = tf.concat([tf.constant(2, tf.int32, (1,)), classes], 0) # bkg
if self._max_num_instance >= 0:
classes = utilities.truncate_or_pad(classes, self._max_num_instance, 0)
labels['instance_labels']['classes'] = classes
# (7) gt-weights
selected_ids = tf.boolean_mask(valid_entity_ids,
random_selection_mask[:num_entity_int])
if self._detection_unit != DetectionClass.PARAGRAPH:
gt_text = tf.gather(data['groundtruth_text'], selected_ids - 1)
gt_weights = tf.cast(tf.strings.length(gt_text) > 0, tf.float32)
else:
text_types = tf.concat(
[
tf.constant([8]),
tf.cast(data['groundtruth_content_type'], tf.int32),
# TODO(longshangbang): temp solution for tfes with no para labels
tf.constant(8, shape=(1000,)),
],
0)
para_types = tf.gather(text_types, selected_ids)
gt_weights = tf.cast(
tf.not_equal(para_types, NOT_ANNOTATED_ID), tf.float32)
gt_weights = tf.concat([tf.constant(1., shape=(1,)), gt_weights], 0) # bkg
if self._max_num_instance >= 0:
gt_weights = utilities.truncate_or_pad(
gt_weights, self._max_num_instance, 0)
labels['instance_labels']['gt_weights'] = gt_weights
# (8) get paragraph label
# In this step, an array `{p_i}` is generated. `p_i` is an integer that
# indicates the group of paragraph which i-th text belongs to. `p_i` == -1
# if this instance is non-text or it has no paragraph labels.
# word -> line -> paragraph
if self._detection_unit == DetectionClass.WORD:
num_hop = 2
elif self._detection_unit == DetectionClass.LINE:
num_hop = 1
elif self._detection_unit == DetectionClass.PARAGRAPH:
num_hop = 0
else:
raise ValueError(f'No such detection unit: {self._detection_unit}. '
'Note that this error should have been raised in '
'previous lines, not here!')
para_ids = tf.identity(selected_ids) # == id in plp + 1
for _ in range(num_hop):
para_ids = tf.gather(padded_parent, para_ids) + 1
text_types = tf.concat(
[
tf.constant([8]),
tf.cast(data['groundtruth_content_type'], tf.int32),
            # TODO(longshangbang): workaround for tfes that have no para labels
tf.constant(8, shape=(1000,)),
],
0)
para_types = tf.gather(text_types, para_ids)
para_ids = para_ids - 1 # revert to id in plp.entities; -1 for no labels
valid_para = tf.cast(tf.not_equal(para_types, NOT_ANNOTATED_ID), tf.int32)
para_ids = valid_para * para_ids + (1 - valid_para) * (-1)
para_ids = tf.concat([tf.constant([-1]), para_ids], 0) # add bkg
has_para_ids = tf.cast(tf.reduce_sum(valid_para) > 0, tf.float32)
if self._max_num_instance >= 0:
para_ids = utilities.truncate_or_pad(
para_ids, self._max_num_instance, 0, -1)
labels['paragraph_labels'] = {
'paragraph_ids': para_ids,
'has_para_ids': has_para_ids
}
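    # Hypothetical example of the parent hops above: with
    # detection_unit == WORD, a word at index 5 whose parent line is index 2
    # and whose grandparent paragraph is index 0 enters selected_ids as 6
    # (= 5 + 1); two gathers through padded_parent map it to 3 and then to 1
    # (= 0 + 1), and para_ids finally stores 0, the paragraph's index in
    # plp.entities.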
def _define_shapes(self, features: TensorDict, labels: TensorDict):
"""Define the tensor shapes for TPU compiling."""
if not self._is_shape_defined:
return
features['images'] = tf.ensure_shape(
features['images'], (self._output_dimension, self._output_dimension, 3))
labels['segmentation_output']['gt_word_score'] = tf.ensure_shape(
labels['segmentation_output']['gt_word_score'],
(self._mask_dimension, self._mask_dimension))
labels['instance_labels']['num_instance'] = tf.ensure_shape(
labels['instance_labels']['num_instance'], [])
if self._max_num_instance >= 0:
labels['instance_labels']['masks_sizes'] = tf.ensure_shape(
labels['instance_labels']['masks_sizes'], (self._max_num_instance,))
labels['instance_labels']['masks'] = tf.ensure_shape(
labels['instance_labels']['masks'],
(self._mask_dimension, self._mask_dimension))
labels['instance_labels']['classes'] = tf.ensure_shape(
labels['instance_labels']['classes'], (self._max_num_instance,))
labels['instance_labels']['gt_weights'] = tf.ensure_shape(
labels['instance_labels']['gt_weights'], (self._max_num_instance,))
labels['paragraph_labels']['paragraph_ids'] = tf.ensure_shape(
labels['paragraph_labels']['paragraph_ids'],
(self._max_num_instance,))
labels['paragraph_labels']['has_para_ids'] = tf.ensure_shape(
labels['paragraph_labels']['has_para_ids'], [])
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrap external code in gin."""
import gin
import gin.tf.external_configurables
import tensorflow as tf
# Tensorflow.
gin.external_configurable(tf.keras.layers.experimental.SyncBatchNormalization)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Universal detector implementation."""
from typing import Any, Dict, Optional, Sequence, Tuple, Union
import gin
import tensorflow as tf
from deeplab2 import config_pb2
from deeplab2.model.decoder import max_deeplab as max_deeplab_head
from deeplab2.model.encoder import axial_resnet_instances
from deeplab2.model.loss import matchers_ops
from official.legacy.transformer import transformer
from official.projects.unified_detector.utils import typing
from official.projects.unified_detector.utils import utilities
EPSILON = 1e-6
@gin.configurable
def universal_detection_loss_weights(
loss_segmentation_word: float = 1e0,
loss_inst_dist: float = 1e0,
loss_mask_id: float = 1e-4,
loss_pq: float = 3e0,
loss_para: float = 1e0) -> Dict[str, float]:
"""A function that returns a dict for the weights of loss terms."""
return {
"loss_segmentation_word": loss_segmentation_word,
"loss_inst_dist": loss_inst_dist,
"loss_mask_id": loss_mask_id,
"loss_pq": loss_pq,
"loss_para": loss_para,
}
@gin.configurable
class LayerNorm(tf.keras.layers.LayerNormalization):
"""A wrapper to allow passing the `training` argument.
  The normalization layers in the MaX-DeepLab implementation are called with
  the `training` argument. This wrapper allows LayerNorm to be used in their
  place.
"""
def call(self,
inputs: tf.Tensor,
training: Optional[bool] = None) -> tf.Tensor:
del training
return super().call(inputs)
@gin.configurable
def get_max_deep_lab_backbone(num_slots: int = 128):
return axial_resnet_instances.get_model(
"max_deeplab_s",
bn_layer=LayerNorm,
block_group_config={
"drop_path_schedule": "linear",
"axial_use_recompute_grad": False
},
backbone_use_transformer_beyond_stride=16,
extra_decoder_use_transformer_beyond_stride=16,
num_mask_slots=num_slots,
max_num_mask_slots=num_slots)
@gin.configurable
class UniversalDetector(tf.keras.layers.Layer):
"""Univeral Detector."""
loss_items = ("loss_pq", "loss_inst_dist", "loss_para", "loss_mask_id",
"loss_segmentation_word")
def __init__(self,
backbone_fn: tf.keras.layers.Layer = get_max_deep_lab_backbone,
mask_threshold: float = 0.4,
class_threshold: float = 0.5,
filter_area: float = 32,
**kwargs: Any):
"""Constructor.
Args:
backbone_fn: The function to initialize a backbone.
mask_threshold: Masks are thresholded with this value.
class_threshold: Classification heads are thresholded with this value.
filter_area: In inference, detections with area smaller than this
threshold will be removed.
**kwargs: other keyword arguments passed to the base class.
"""
super().__init__(**kwargs)
# Model
self._backbone_fn = backbone_fn()
self._decoder = _get_decoder_head()
self._class_embed_head, self._para_embed_head = _get_embed_head()
self._para_head, self._para_proj = _get_para_head()
# Losses
# self._max_deeplab_loss = _get_max_deeplab_loss()
self._loss_weights = universal_detection_loss_weights()
# Post-processing
self._mask_threshold = mask_threshold
self._class_threshold = class_threshold
self._filter_area = filter_area
def _preprocess_labels(self, labels: typing.TensorDict):
    # Preprocessing:
    # Convert the integer mask to one-hot encoded masks.
num_instances = utilities.resolve_shape(
labels["instance_labels"]["masks_sizes"])[1]
labels["instance_labels"]["masks"] = tf.one_hot(
labels["instance_labels"]["masks"],
depth=num_instances,
axis=1,
dtype=tf.float32) # (B, N, H, W)
def compute_losses(
self, labels: typing.NestedTensorDict, outputs: typing.NestedTensorDict
) -> Tuple[tf.Tensor, typing.NestedTensorDict]:
"""Computes the loss.
Args:
labels: A dictionary of ground-truth labels.
outputs: Output from self.call().
Returns:
A scalar total loss tensor and a dictionary for individual losses.
"""
loss_dict = {}
self._preprocess_labels(labels)
# Main loss: PQ loss.
_entity_mask_loss(loss_dict, labels["instance_labels"],
outputs["instance_output"])
# Auxiliary loss 1: semantic loss
_semantic_loss(loss_dict, labels["segmentation_output"],
outputs["segmentation_output"])
# Auxiliary loss 2: instance discrimination
_instance_discrimination_loss(loss_dict, labels["instance_labels"], outputs)
# Auxiliary loss 3: mask id
_mask_id_xent_loss(loss_dict, labels["instance_labels"], outputs)
# Auxiliary loss 4: paragraph grouping
_paragraph_grouping_loss(loss_dict, labels, outputs)
weighted_loss = [self._loss_weights[k] * v for k, v in loss_dict.items()]
total_loss = sum(weighted_loss)
return total_loss, loss_dict
def call(self,
features: typing.TensorDict,
training: bool = False) -> typing.NestedTensorDict:
"""Forward pass of the model.
Args:
features: The input features: {"images": tf.Tensor}. Shape = [B, H, W, C]
training: Whether it's training mode.
Returns:
A dictionary of output with this structure:
{
"max_deep_lab": {
All the max deeplab outputs are here, including both backbone and
decoder.
}
"segmentation_output": {
"word_score": tf.Tensor, [B, h, w],
}
"instance_output": {
"cls_logits": tf.Tensor, [B, N, C],
"mask_id_logits": tf.Tensor, [B, H, W, N],
"cls_prob": tf.Tensor, [B, N, C],
"mask_id_prob": tf.Tensor, [B, H, W, N],
}
"postprocessed": {
"classes": A (B, N) tensor for the class ids. Zero for non-firing
slots.
"binary_masks": A (B, H, W, N) tensor for the N binary masks. Masks
for void cls are set to zero.
"confidence": A (B, N) float tensor for the confidence of "classes".
"mask_area": A (B, N) float tensor for the area of each mask.
}
"transformer_group_feature": (B, N, C) float tensor (normalized),
"para_affinity": (B, N, N) float tensor.
}
Class-0 is for void. Class-(C-1) is for background. Class-1~(C-2) is for
valid classes.
"""
# backbone
backbone_output = self._backbone_fn(features["images"], training)
# split instance embedding and paragraph embedding;
# then perform paragraph grouping
para_fts = self._get_para_outputs(backbone_output, training)
affinity = tf.linalg.matmul(para_fts, para_fts, transpose_b=True)
# text detection head
decoder_output = self._decoder(backbone_output, training)
output_dict = {
"max_deep_lab": decoder_output,
"transformer_group_feature": para_fts,
"para_affinity": affinity,
}
input_shape = utilities.resolve_shape(features["images"])
self._get_semantic_outputs(output_dict, input_shape)
self._get_instance_outputs(output_dict, input_shape)
self._postprocess(output_dict)
return output_dict
def _get_para_outputs(self, outputs: typing.TensorDict,
training: bool) -> tf.Tensor:
"""Apply the paragraph head.
This function first splits the features for instance classification and
instance grouping. Then, the additional grouping branch (transformer layers)
is applied to further encode the grouping features. Finally, a tensor of
normalized grouping features is returned.
Args:
outputs: output dictionary from the backbone.
training: training / eval mode mark.
Returns:
The normalized paragraph embedding vector of shape (B, N, C).
"""
# Project the object embeddings into classification feature and grouping
# feature.
fts = outputs["transformer_class_feature"] # B,N,C
class_feature = self._class_embed_head(fts, training)
group_feature = self._para_embed_head(fts, training)
outputs["transformer_class_feature"] = class_feature
outputs["transformer_group_feature"] = group_feature
# Feed the grouping features into additional group encoding branch.
    # First we need to build the attention_bias which is used by the standard
    # transformer encoder.
input_shape = utilities.resolve_shape(group_feature)
b = input_shape[0]
n = int(input_shape[1])
seq_len = tf.constant(n, shape=(b,))
padding_mask = utilities.get_padding_mask_from_valid_lengths(
seq_len, n, tf.float32)
attention_bias = utilities.get_transformer_attention_bias(padding_mask)
group_feature = self._para_proj(
self._para_head(group_feature, attention_bias, None, training))
return tf.math.l2_normalize(group_feature, axis=-1)
def _get_semantic_outputs(self, outputs: typing.NestedTensorDict,
input_shape: tf.TensorShape):
"""Add `segmentation_output` to outputs.
Args:
outputs: A dictionary of outputs.
input_shape: The shape of the input images.
"""
h, w = input_shape[1:3]
# B, H/4, W/4, C
semantic_logits = outputs["max_deep_lab"]["semantic_logits"]
textness, unused_logits = tf.split(semantic_logits, [2, -1], -1)
# Channel[0:2], textness. c0: non-textness, c1: textness.
word_score = tf.nn.softmax(textness, -1, "word_score")[:, :, :, 1:2]
word_score = tf.squeeze(tf.image.resize(word_score, (h, w)), -1)
# Channel[2:] not used yet
outputs["segmentation_output"] = {"word_score": word_score}
def _get_instance_outputs(self, outputs: typing.NestedTensorDict,
input_shape: tf.TensorShape):
"""Add `instance_output` to outputs.
Args:
outputs: A dictionary of outputs.
input_shape: The shape of the input images.
These following fields are added to outputs["instance_output"]:
"cls_logits": tf.Tensor, [B, N, C].
"mask_id_logits": tf.Tensor, [B, H, W, N].
"cls_prob": tf.Tensor, [B, N, C], softmax probability.
"mask_id_prob": tf.Tensor, [B, H, W, N], softmax probability. They are
used in training. Masks are all resized to full resolution.
"""
# Get instance_output
h, w = input_shape[1:3]
## Classes
class_logits = outputs["max_deep_lab"]["transformer_class_logits"]
# The MaX-DeepLab repo uses the last logit for void; but we use 0.
# Therefore we shift the logits here.
class_logits = tf.roll(class_logits, shift=1, axis=-1)
class_prob = tf.nn.softmax(class_logits)
## Masks
mask_id_logits = outputs["max_deep_lab"]["pixel_space_mask_logits"]
mask_id_prob = tf.nn.softmax(mask_id_logits)
mask_id_logits = tf.image.resize(mask_id_logits, (h, w))
mask_id_prob = tf.image.resize(mask_id_prob, (h, w))
outputs["instance_output"] = {
"cls_logits": class_logits,
"mask_id_logits": mask_id_logits,
"cls_prob": class_prob,
"mask_id_prob": mask_id_prob,
}
def _postprocess(self, outputs: typing.NestedTensorDict):
"""Post-process (filtering) the outputs.
Args:
outputs: A dictionary of outputs.
These following fields are added to outputs["postprocessed"]:
"classes": A (B,N) integer tensor for the class ids.
"binary_masks": A (B, H, W, N) tensor for the N binarized 0/1 masks. Masks
for void cls are set to zero.
"confidence": A (B, N) float tensor for the confidence of "classes".
"mask_area": A (B, N) float tensor for the area of each mask. They are
used in inference / visualization.
"""
# Get postprocessed outputs
outputs["postprocessed"] = {}
## Masks:
mask_id_prob = outputs["instance_output"]["mask_id_prob"]
mask_max_prob = tf.reduce_max(mask_id_prob, axis=-1, keepdims=True)
thresholded_binary_masks = tf.cast(
tf.math.logical_and(
tf.equal(mask_max_prob, mask_id_prob),
tf.greater_equal(mask_max_prob, self._mask_threshold)), tf.float32)
area = tf.reduce_sum(thresholded_binary_masks, axis=(1, 2)) # (B, N)
## Classification:
cls_prob = outputs["instance_output"]["cls_prob"]
cls_max_prob = tf.reduce_max(cls_prob, axis=-1) # B, N
cls_max_id = tf.cast(tf.argmax(cls_prob, axis=-1), tf.float32) # B, N
## filtering
c = utilities.resolve_shape(cls_prob)[2]
non_void = tf.reduce_all(
tf.stack(
[
tf.greater_equal(area, self._filter_area), # mask large enough.
tf.not_equal(cls_max_id, 0), # class-0 is for non-object.
tf.not_equal(cls_max_id,
c - 1), # class-(c-1) is for background (last).
tf.greater_equal(cls_max_prob,
self._class_threshold) # prob >= thr
],
axis=-1),
axis=-1)
non_void = tf.cast(non_void, tf.float32)
# Storing
outputs["postprocessed"]["classes"] = tf.cast(cls_max_id * non_void,
tf.int32)
b, n = utilities.resolve_shape(non_void)
outputs["postprocessed"]["binary_masks"] = (
thresholded_binary_masks * tf.reshape(non_void, (b, 1, 1, n)))
outputs["postprocessed"]["confidence"] = cls_max_prob
outputs["postprocessed"]["mask_area"] = area
def _coloring(self, masks: tf.Tensor) -> tf.Tensor:
"""Coloring segmentation masks.
Used in visualization.
Args:
masks: A float binary tensor of shape (B, H, W, N), representing `B`
samples, with `N` masks of size `H*W` each. Each of the `N` masks will
be assigned a random color.
Returns:
A (b, h, w, 3) float tensor in [0., 1.] for the coloring result.
"""
b, h, w, n = utilities.resolve_shape(masks)
palette = tf.random.uniform((1, n, 3), 0.5, 1.)
colored = tf.reshape(
tf.matmul(tf.reshape(masks, (b, -1, n)), palette), (b, h, w, 3))
return colored
def visualize(self,
outputs: typing.NestedTensorDict,
labels: Optional[typing.TensorDict] = None):
"""Visualizes the outputs and labels.
Args:
outputs: A dictionary of outputs.
labels: A dictionary of labels.
The following dict is added to outputs["visualization"]: {
"instance": {
"pred": A (B, H, W, 3) tensor for the visualized map in [0,1].
"gt": A (B, H, W, 3) tensor for the visualized map in [0,1], if labels
is present.
"concat": Concatenation of "prediction" and "gt" along width axis, if
labels is present. }
"seg-text": {... Similar to above, but the shape is (B, H, W, 1).} } All
of these tensors have a rank of 4 (B, H, W, C).
"""
outputs["visualization"] = {}
# 1. prediction
# 1.1 instance mask
binary_masks = outputs["postprocessed"]["binary_masks"]
outputs["visualization"]["instance"] = {
"pred": self._coloring(binary_masks),
}
# 1.2 text-seg
outputs["visualization"]["seg-text"] = {
"pred":
tf.expand_dims(outputs["segmentation_output"]["word_score"], -1),
}
# 2. labels
if labels is not None:
# 2.1 instance mask
# (B, N, H, W) -> (B, H, W, N); the first one is bkg so removed.
gt_masks = tf.transpose(labels["instance_labels"]["masks"][:, 1:],
(0, 2, 3, 1))
outputs["visualization"]["instance"]["gt"] = self._coloring(gt_masks)
# 2.2 text-seg
outputs["visualization"]["seg-text"]["gt"] = tf.expand_dims(
labels["segmentation_output"]["gt_word_score"], -1)
# 3. concat
for v in outputs["visualization"].values():
# Resize to make the size align. The prediction always has stride=1
# resolution, so we make gt align with pred instead of vice versa.
v["concat"] = tf.concat(
[v["pred"],
tf.image.resize(v["gt"],
tf.shape(v["pred"])[1:3])],
axis=2)
@tf.function
def serve(self, image_tensor: tf.Tensor) -> typing.NestedTensorDict:
"""Method to be exported for SavedModel.
Args:
image_tensor: A float32 normalized tensor representing an image of shape
[1, height, width, channels].
Returns:
Dict of output:
classes: (B, N) int32 tensor == o["postprocessed"]["classes"]
masks: (B, H, W, N) float32 tensor == o["postprocessed"]["binary_masks"]
groups: (B, N, N) float32 tensor == o["para_affinity"]
confidence: A (B, N) float tensor == o["postprocessed"]["confidence"]
mask_area: A (B, N) float tensor == o["postprocessed"]["mask_area"]
"""
features = {"images": image_tensor}
nn_outputs = self(features, False)
outputs = {
"classes": nn_outputs["postprocessed"]["classes"],
"masks": nn_outputs["postprocessed"]["binary_masks"],
"confidence": nn_outputs["postprocessed"]["confidence"],
"mask_area": nn_outputs["postprocessed"]["mask_area"],
"groups": nn_outputs["para_affinity"],
}
return outputs
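# A hypothetical export sketch (not part of the original file) showing how
# `serve` could be exposed as a SavedModel signature; `ckpt_path`, `export_dir`
# and the 1024x1024 input size are placeholder assumptions.
#
#   model = UniversalDetector()
#   tf.train.Checkpoint(model=model).restore(ckpt_path).expect_partial()
#   signature = model.serve.get_concrete_function(
#       tf.TensorSpec([1, 1024, 1024, 3], tf.float32))
#   tf.saved_model.save(
#       model, export_dir, signatures={'serving_default': signature})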
@gin.configurable()
def _get_decoder_head(
atrous_rates: Sequence[int] = (6, 12, 18),
pixel_space_dim: int = 128,
pixel_space_intermediate: int = 256,
low_level: Sequence[Dict[str, Union[str, int]]] = ({
"feature_key": "res3",
"channels_project": 64,
}, {
"feature_key": "res2",
"channels_project": 32,
}),
num_classes=3,
aux_sem_intermediate=256,
norm_fn=tf.keras.layers.BatchNormalization,
) -> max_deeplab_head.MaXDeepLab:
"""Get the MaX-DeepLab prediction head.
Args:
    atrous_rates: Dilation rates for atrous conv in the semantic head.
pixel_space_dim: The dimension for the final panoptic features.
pixel_space_intermediate: The dimension for the layer before
`pixel_space_dim` (i.e. the separable 5x5 layer).
low_level: A list of dicts for the feature pyramid in forming the semantic
output. Each dict represents one skip-path from the backbone.
    num_classes: Number of classes (entities + bkg) including void. For
      example, if we only want to detect words, then `num_classes` = 3 (1 for
      word, 1 for bkg, and 1 for void).
aux_sem_intermediate: Similar to `pixel_space_intermediate`, but for the
auxiliary semantic output head.
norm_fn: The normalization function used in the head.
Returns:
A MaX-DeepLab decoder head (as a keras layer).
"""
# Initialize the configs.
configs = config_pb2.ModelOptions()
configs.decoder.feature_key = "feature_semantic"
configs.decoder.atrous_rates.extend(atrous_rates)
configs.max_deeplab.pixel_space_head.output_channels = pixel_space_dim
configs.max_deeplab.pixel_space_head.head_channels = pixel_space_intermediate
for low_level_config in low_level:
low_level_ = configs.max_deeplab.auxiliary_low_level.add()
low_level_.feature_key = low_level_config["feature_key"]
low_level_.channels_project = low_level_config["channels_project"]
configs.max_deeplab.auxiliary_semantic_head.output_channels = num_classes
configs.max_deeplab.auxiliary_semantic_head.head_channels = aux_sem_intermediate
return max_deeplab_head.MaXDeepLab(configs.decoder,
configs.max_deeplab, 0, norm_fn)
class PseudoLayer(tf.keras.layers.Layer):
"""Pseudo layer for ablation study.
The `call()` function has the same argument signature as a transformer
encoder stack. `unused_ph1` and `unused_ph2` are place holders for this
purpose. When studying the effectiveness of using transformer as the
grouping branch, we can use this PseudoLayer to replace the transformer to
use as a no-transformer baseline.
To use a single projection layer instead of transformer, simply set `extra_fc`
to True.
"""
def __init__(self, extra_fc: bool):
super().__init__(name="extra_fc")
self._extra_fc = extra_fc
if extra_fc:
self._layer = tf.keras.Sequential([
tf.keras.layers.Dense(256, activation="relu"),
tf.keras.layers.LayerNormalization(),
])
def call(self,
fts: tf.Tensor,
unused_ph1: Optional[tf.Tensor],
unused_ph2: Optional[tf.Tensor],
training: Optional[bool] = None) -> tf.Tensor:
"""See base class."""
if self._extra_fc:
return self._layer(fts, training)
return fts
@gin.configurable()
def _get_embed_head(
dimension=256,
norm_fn=tf.keras.layers.BatchNormalization
) -> Tuple[tf.keras.Sequential, tf.keras.Sequential]:
"""Projection layers to get instance & grouping features."""
instance_head = tf.keras.Sequential([
tf.keras.layers.Dense(dimension, use_bias=False),
norm_fn(),
tf.keras.layers.ReLU(),
])
grouping_head = tf.keras.Sequential([
tf.keras.layers.Dense(dimension, use_bias=False),
norm_fn(),
tf.keras.layers.ReLU(),
])
return instance_head, grouping_head
@gin.configurable()
def _get_para_head(
dimension=128,
num_layer=3,
extra_fc=False) -> Tuple[tf.keras.layers.Layer, tf.keras.layers.Layer]:
"""Get the additional para head.
Args:
dimension: the dimension of the final output.
num_layer: the number of transformer layer.
extra_fc: Whether an extra single fully-connected layer is used, when
num_layer=0.
Returns:
an encoder and a projection layer for the grouping features.
"""
if num_layer > 0:
encoder = transformer.EncoderStack(
params={
"hidden_size": 256,
"num_hidden_layers": num_layer,
"num_heads": 4,
"filter_size": 512,
"initializer_gain": 1.0,
"attention_dropout": 0.1,
"relu_dropout": 0.1,
"layer_postprocess_dropout": 0.1,
"allow_ffn_pad": True,
})
else:
encoder = PseudoLayer(extra_fc)
dense = tf.keras.layers.Dense(dimension)
return encoder, dense
def _dice_sim(pred: tf.Tensor, ground_truth: tf.Tensor) -> tf.Tensor:
"""Dice Coefficient for mask similarity.
Args:
pred: The predicted mask. [B, N, H, W], in [0, 1].
ground_truth: The ground-truth mask. [B, N, H, W], in [0, 1] or {0, 1}.
Returns:
A matrix for the losses: m[b, i, j] is the dice similarity between pred `i`
and gt `j` in batch `b`.
"""
b, n = utilities.resolve_shape(pred)[:2]
ground_truth = tf.reshape(
tf.transpose(ground_truth, (0, 2, 3, 1)), (b, -1, n)) # B, HW, N
pred = tf.reshape(pred, (b, n, -1)) # B, N, HW
numerator = tf.matmul(pred, ground_truth) * 2.
# TODO(longshangbang): The official implementation does not square the scores.
# Need to do experiment to determine which one is better.
denominator = (
tf.math.reduce_sum(tf.math.square(ground_truth), 1, keepdims=True) +
tf.math.reduce_sum(tf.math.square(pred), 2, keepdims=True))
return (numerator + EPSILON) / (denominator + EPSILON)
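# Hypothetical numeric check of `_dice_sim` (not in the original file): with a
# single 1x2 "image", pred p = [0.5, 1.0] and gt g = [0.0, 1.0] give
# numerator = 2 * (0.5 * 0 + 1 * 1) = 2 and
# denominator = (0^2 + 1^2) + (0.5^2 + 1^2) = 2.25, so the dice similarity is
# about 2 / 2.25 ≈ 0.89 (ignoring EPSILON).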
def _semantic_loss(
loss_dict: Dict[str, tf.Tensor],
labels: tf.Tensor,
outputs: tf.Tensor,
):
"""Auxiliary semantic loss.
Currently, these losses are added:
(1) text/non-text heatmap
Args:
loss_dict: A dictionary for the loss. The values are loss scalars.
labels: The label dictionary containing:
`gt_word_score`: (B, H, W) tensor for the text/non-text map.
outputs: The output dictionary containing:
`word_score`: (B, H, W) prediction tensor for `gt_word_score`
"""
pred = tf.expand_dims(outputs["word_score"], 1)
gt = tf.expand_dims(labels["gt_word_score"], 1)
loss_dict["loss_segmentation_word"] = 1. - tf.reduce_mean(_dice_sim(pred, gt))
@gin.configurable
def _entity_mask_loss(loss_dict: Dict[str, tf.Tensor],
labels: tf.Tensor,
outputs: tf.Tensor,
alpha: float = gin.REQUIRED):
"""PQ loss for entity-mask training.
This method adds the PQ loss term to loss_dict directly. The match result will
also be stored in outputs (As a [B, N_pred, N_gt] float tensor).
Args:
loss_dict: A dictionary for the loss. The values are loss scalars.
labels: A dict containing: `num_instance` - (B,) `masks` - (B, N, H, W)
`classes` - (B, N)
outputs: A dict containing:
`cls_prob`: (B, N, C)
`mask_id_prob`: (B, H, W, N)
`cls_logits`: (B, N, C)
`mask_id_logits`: (B, H, W, N)
alpha: Weight for pos/neg balance.
"""
# Classification score: (B, N, N)
# in batch b, the probability of prediction i being class of gt j, i.e.:
# score[b, i, j] = pred_cls[b, i, gt_cls[b, j]]
gt_cls = labels["classes"] # (B, N)
pred_cls = outputs["cls_prob"] # (B, N, C)
b, n = utilities.resolve_shape(pred_cls)[:2]
# indices[b, i, j] = gt_cls[b, j]
indices = tf.tile(tf.expand_dims(gt_cls, 1), (1, n, 1))
cls_score = tf.gather(pred_cls, tf.cast(indices, tf.int32), batch_dims=2)
# Mask score (dice): (B, N, N)
# mask_score[b, i, j]: dice-similarity for pred i and gt j in batch b.
mask_score = _dice_sim(
tf.transpose(outputs["mask_id_prob"], (0, 3, 1, 2)), labels["masks"])
# Get similarity matrix and matching.
# padded mask[b, j, i] = -1 << other scores, if i >= num_instance[b]
similarity = cls_score * mask_score
padded_mask = tf.cast(tf.reshape(tf.range(n), (1, 1, n)), tf.float32)
padded_mask = tf.cast(
tf.math.greater_equal(padded_mask,
tf.reshape(labels["num_instance"], (b, 1, 1))),
tf.float32)
# The constant value for padding has no effect.
masked_similarity = similarity * (1. - padded_mask) + padded_mask * (-1.)
matched_mask = matchers_ops.hungarian_matching(-masked_similarity)
matched_mask = tf.cast(matched_mask, tf.float32) * (1 - padded_mask)
outputs["matched_mask"] = matched_mask
# Pos loss
loss_pos = (
tf.stop_gradient(cls_score) * (-mask_score) +
tf.stop_gradient(mask_score) * (-tf.math.log(cls_score)))
loss_pos = tf.reduce_sum(loss_pos * matched_mask, axis=[1, 2]) # (B,)
# Neg loss
matched_pred = tf.cast(tf.reduce_sum(matched_mask, axis=2) > 0,
tf.float32) # (B, N)
# 0 for void class
log_loss = -tf.nn.log_softmax(outputs["cls_logits"])[:, :, 0] # (B, N)
loss_neg = tf.reduce_sum(log_loss * (1. - matched_pred), axis=-1) # (B,)
loss_pq = (alpha * loss_pos + (1 - alpha) * loss_neg) / n
loss_pq = tf.reduce_mean(loss_pq)
loss_dict["loss_pq"] = loss_pq
@gin.configurable
def _instance_discrimination_loss(loss_dict: Dict[str, Any],
labels: Dict[str, Any],
outputs: Dict[str, Any],
tau: float = gin.REQUIRED):
"""Instance discrimination loss.
This method adds the ID loss term to loss_dict directly.
Args:
loss_dict: A dictionary for the loss. The values are loss scalars.
labels: The label dictionary.
outputs: The output dictionary.
tau: The temperature term in the loss
"""
# The normalized feature, shape=(B, H/4, W/4, D)
g = outputs["max_deep_lab"]["pixel_space_normalized_feature"]
b, h, w = utilities.resolve_shape(g)[:3]
# The ground-truth masks, shape=(B, N, H, W) --> (B, N, H/4, W/4)
m = labels["masks"]
m = tf.image.resize(
tf.transpose(m, (0, 2, 3, 1)), (h, w),
tf.image.ResizeMethod.NEAREST_NEIGHBOR)
m = tf.transpose(m, (0, 3, 1, 2))
# The number of ground-truth instance (K), shape=(B,)
num = labels["num_instance"]
n = utilities.resolve_shape(m)[1] # max number of predictions
# is_void[b, i] = 1 if instance i in batch b is a padded slot.
is_void = tf.cast(tf.expand_dims(tf.range(n), 0), tf.float32) # (1, n)
is_void = tf.cast(
tf.math.greater_equal(is_void, tf.expand_dims(num, 1)), tf.float32)
# (B, N, D)
t = tf.math.l2_normalize(tf.einsum("bhwd,bnhw->bnd", g, m), axis=-1)
inst_dist_logits = tf.einsum("bhwd,bid->bhwi", g, t) / tau # (B, H, W, N)
inst_dist_logits = inst_dist_logits - 100. * tf.reshape(is_void, (b, 1, 1, n))
mask_id = tf.cast(
tf.einsum("bnhw,n->bhw", m, tf.range(n, dtype=tf.float32)), tf.int32)
loss_map = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=mask_id, logits=inst_dist_logits) # B, H, W
valid_mask = tf.reduce_sum(m, axis=1)
loss_inst_dist = (
(tf.reduce_sum(loss_map * valid_mask, axis=[1, 2]) + EPSILON) /
(tf.reduce_sum(valid_mask, axis=[1, 2]) + EPSILON))
loss_dict["loss_inst_dist"] = tf.reduce_mean(loss_inst_dist)
@gin.configurable
def _paragraph_grouping_loss(
loss_dict: Dict[str, Any],
labels: Dict[str, Any],
outputs: Dict[str, Any],
tau: float = gin.REQUIRED,
loss_mode="vanilla",
fl_alpha: float = 0.25,
fl_gamma: float = 2.,
):
"""Instance discrimination loss.
This method adds the para discrimination loss term to loss_dict directly.
Args:
loss_dict: A dictionary for the loss. The values are loss scalars.
labels: The label dictionary.
outputs: The output dictionary.
tau: The temperature term in the loss
loss_mode: The type of loss.
fl_alpha: alpha value in focal loss
fl_gamma: gamma value in focal loss
"""
if "paragraph_labels" not in labels:
loss_dict["loss_para"] = 0.
return
# step 1:
# obtain the paragraph labels for each prediction
# (batch, pred, gt)
matched_matrix = outputs["instance_output"]["matched_mask"] # B, N, N
para_label_gt = labels["paragraph_labels"]["paragraph_ids"] # B, N
has_para_label_gt = (
labels["paragraph_labels"]["has_para_ids"][:, tf.newaxis, tf.newaxis])
# '0' means no paragraph labels
pred_label_gt = tf.einsum("bij,bj->bi", matched_matrix,
tf.cast(para_label_gt + 1, tf.float32))
pred_label_gt_pad_col = tf.expand_dims(pred_label_gt, -1) # b,n,1
pred_label_gt_pad_row = tf.expand_dims(pred_label_gt, 1) # b,1,n
gt_affinity = tf.cast(
tf.equal(pred_label_gt_pad_col, pred_label_gt_pad_row), tf.float32)
gt_affinity_mask = (
has_para_label_gt * pred_label_gt_pad_col * pred_label_gt_pad_row)
gt_affinity_mask = tf.cast(tf.not_equal(gt_affinity_mask, 0.), tf.float32)
# step 2:
# get affinity matrix
affinity = outputs["para_affinity"]
# step 3:
# compute loss
loss_fn = tf.keras.losses.BinaryCrossentropy(
from_logits=True,
label_smoothing=0,
axis=-1,
reduction=tf.keras.losses.Reduction.NONE,
name="para_dist")
affinity = tf.reshape(affinity, (-1, 1)) # (b*n*n, 1)
gt_affinity = tf.reshape(gt_affinity, (-1, 1)) # (b*n*n, 1)
gt_affinity_mask = tf.reshape(gt_affinity_mask, (-1,)) # (b*n*n,)
pointwise_loss = loss_fn(gt_affinity, affinity / tau) # (b*n*n,)
if loss_mode == "vanilla":
loss = (
tf.reduce_sum(pointwise_loss * gt_affinity_mask) /
(tf.reduce_sum(gt_affinity_mask) + EPSILON))
elif loss_mode == "balanced":
# pos
pos_mask = gt_affinity_mask * gt_affinity[:, 0]
pos_loss = (
tf.reduce_sum(pointwise_loss * pos_mask) /
(tf.reduce_sum(pos_mask) + EPSILON))
# neg
neg_mask = gt_affinity_mask * (1. - gt_affinity[:, 0])
neg_loss = (
tf.reduce_sum(pointwise_loss * neg_mask) /
(tf.reduce_sum(neg_mask) + EPSILON))
loss = 0.25 * pos_loss + 0.75 * neg_loss
elif loss_mode == "focal":
alpha_wt = fl_alpha * gt_affinity + (1. - fl_alpha) * (1. - gt_affinity)
prob_pos = tf.math.sigmoid(affinity / tau)
pt = prob_pos * gt_affinity + (1. - prob_pos) * (1. - gt_affinity)
fl_loss_pw = tf.stop_gradient(
alpha_wt * tf.pow(1. - pt, fl_gamma))[:, 0] * pointwise_loss
loss = (
tf.reduce_sum(fl_loss_pw * gt_affinity_mask) /
(tf.reduce_sum(gt_affinity_mask) + EPSILON))
else:
raise ValueError(f"Not supported loss mode: {loss_mode}")
loss_dict["loss_para"] = loss
def _mask_id_xent_loss(loss_dict: Dict[str, Any], labels: Dict[str, Any],
outputs: Dict[str, Any]):
"""Mask ID loss.
This method adds the mask ID loss term to loss_dict directly.
Args:
loss_dict: A dictionary for the loss. The values are loss scalars.
labels: The label dictionary.
outputs: The output dictionary.
"""
# (B, N, H, W)
mask_gt = labels["masks"]
# B, H, W, N
mask_id_logits = outputs["instance_output"]["mask_id_logits"]
# B, N, N
matched_matrix = outputs["instance_output"]["matched_mask"]
# B, N
gt_to_pred_id = tf.cast(tf.math.argmax(matched_matrix, axis=1), tf.float32)
# B, H, W
mask_id_labels = tf.cast(
tf.einsum("bnhw,bn->bhw", mask_gt, gt_to_pred_id), tf.int32)
loss_map = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=mask_id_labels, logits=mask_id_logits)
valid_mask = tf.reduce_sum(mask_gt, axis=1)
loss_mask_id = (
(tf.reduce_sum(loss_map * valid_mask, axis=[1, 2]) + EPSILON) /
(tf.reduce_sum(valid_mask, axis=[1, 2]) + EPSILON))
loss_dict["loss_mask_id"] = tf.reduce_mean(loss_mask_id)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""All necessary imports for registration."""
# pylint: disable=unused-import
from official.projects.unified_detector import external_configurables
from official.projects.unified_detector.configs import ocr_config
from official.projects.unified_detector.tasks import ocr_task
from official.vision import registry_imports
tf-nightly
gin-config
opencv-python==4.1.2.30
absl-py>=1.0.0
shapely>=1.8.1
apache_beam>=2.37.0
matplotlib>=3.5.1
notebook>=6.4.10
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""A binary to run unified detector."""
import json
import os
from typing import Any, Dict, List, Sequence, Tuple
from absl import app
from absl import flags
from absl import logging
import cv2
import gin
import numpy as np
import tensorflow as tf
import tqdm
from official.projects.unified_detector import external_configurables # pylint: disable=unused-import
from official.projects.unified_detector.modeling import universal_detector
from official.projects.unified_detector.utils import utilities
# group two lines into a paragraph if affinity score higher than this
_PARA_GROUP_THR = 0.5
# MODEL spec
_GIN_FILE = flags.DEFINE_string(
'gin_file', None, 'Path to the Gin file that defines the model.')
_CKPT_PATH = flags.DEFINE_string(
'ckpt_path', None, 'Path to the checkpoint directory.')
_IMG_SIZE = flags.DEFINE_integer(
'img_size', 1024, 'Size of the image fed to the model.')
# Input & Output
# Note that, all images specified by `img_file` and `img_dir` will be processed.
_IMG_FILE = flags.DEFINE_multi_string('img_file', [], 'Paths to the images.')
_IMG_DIR = flags.DEFINE_multi_string(
'img_dir', [], 'Paths to the image directories.')
_OUTPUT_PATH = flags.DEFINE_string('output_path', None, 'Path for the output.')
_VIS_DIR = flags.DEFINE_string(
'vis_dir', None, 'Path for the visualization output.')
def _preprocess(raw_image: np.ndarray) -> Tuple[np.ndarray, float]:
"""Resize, pad, and normalize a raw image; also return the resize ratio."""
# (1) convert to tf.Tensor and float32.
img_tensor = tf.convert_to_tensor(raw_image, dtype=tf.float32)
# (2) pad to square.
height, width = img_tensor.shape[:2]
maximum_side = tf.maximum(height, width)
height_pad = maximum_side - height
width_pad = maximum_side - width
img_tensor = tf.pad(
img_tensor, [[0, height_pad], [0, width_pad], [0, 0]],
constant_values=127)
ratio = maximum_side / _IMG_SIZE.value
# (3) resize the padded square image to the target size.
img_tensor = tf.image.resize(
img_tensor, (_IMG_SIZE.value, _IMG_SIZE.value))
img_tensor = tf.cast(img_tensor, tf.uint8)
# (4) normalize
img_tensor = utilities.normalize_image_to_range(img_tensor)
# (5) Add batch dimension and return as numpy array.
return tf.expand_dims(img_tensor, 0).numpy(), float(ratio)
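# Illustrative sketch (not part of the original file): under the default
# --img_size=1024, a 600x800 image is padded to 800x800 and resized to
# 1024x1024, so `ratio` = 800 / 1024 = 0.78125; multiplying predicted
# coordinates by this ratio maps them back to the original resolution.
#   dummy = np.zeros((600, 800, 3), dtype=np.uint8)
#   batched, ratio = _preprocess(dummy)
#   # batched.shape == (1, 1024, 1024, 3); ratio == 0.78125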
def load_model() -> tf.keras.layers.Layer:
"""Build the model from the Gin config and restore the checkpoint."""
gin.parse_config_file(_GIN_FILE.value)
model = universal_detector.UniversalDetector()
ckpt = tf.train.Checkpoint(model=model)
ckpt_path = _CKPT_PATH.value
logging.info('Load ckpt from: %s', ckpt_path)
ckpt.restore(ckpt_path).expect_partial()
return model
def inference(img_file: str,
model: tf.keras.layers.Layer) -> List[Dict[str, Any]]:
"""Run the model on one image and return the detected paragraphs."""
img = cv2.cvtColor(cv2.imread(img_file), cv2.COLOR_BGR2RGB)
img_ndarray, ratio = _preprocess(img)
output_dict = model.serve(img_ndarray)
class_tensor = output_dict['classes'].numpy()
mask_tensor = output_dict['masks'].numpy()
group_tensor = output_dict['groups'].numpy()
indices = np.where(class_tensor[0])[0].tolist() # indices of positive slots.
mask_list = [
mask_tensor[0, :, :, index] for index in indices] # List of mask ndarray.
# Form lines and words
lines = []
line_indices = []
for index, mask in tqdm.tqdm(zip(indices, mask_list)):
line = {
'words': [],
'text': '',
}
contours, _ = cv2.findContours(
(mask > 0.).astype(np.uint8),
cv2.RETR_TREE,
cv2.CHAIN_APPROX_SIMPLE)[-2:]
for contour in contours:
if (isinstance(contour, np.ndarray) and
len(contour.shape) == 3 and
contour.shape[0] > 2 and
contour.shape[1] == 1 and
contour.shape[2] == 2):
cnt_list = (contour[:, 0] * ratio).astype(np.int32).tolist()
line['words'].append({'text': '', 'vertices': cnt_list})
else:
logging.error('Invalid contour: %s, discarded', str(contour))
if line['words']:
lines.append(line)
line_indices.append(index)
# Form paragraphs
line_grouping = utilities.DisjointSet(len(line_indices))
affinity = group_tensor[0][line_indices][:, line_indices]
for i1, i2 in zip(*np.where(affinity > _PARA_GROUP_THR)):
line_grouping.union(i1, i2)
line_groups = line_grouping.to_group()
paragraphs = []
for line_group in line_groups:
paragraph = {'lines': []}
for id_ in line_group:
paragraph['lines'].append(lines[id_])
if paragraph:
paragraphs.append(paragraph)
return paragraphs
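# Illustrative sketch (not part of the original file): how lines are merged
# into paragraphs above. Assuming three detected lines and a symmetric
# affinity matrix in which only lines 0 and 1 exceed _PARA_GROUP_THR:
#   affinity = np.array([[1.0, 0.9, 0.1],
#                        [0.9, 1.0, 0.2],
#                        [0.1, 0.2, 1.0]])
#   grouping = utilities.DisjointSet(3)
#   for i1, i2 in zip(*np.where(affinity > _PARA_GROUP_THR)):
#     grouping.union(i1, i2)
#   # grouping.to_group() yields two groups, e.g. [[0, 1], [2]].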
def main(argv: Sequence[str]) -> None:
if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.')
# Get list of images
img_lists = []
img_lists.extend(_IMG_FILE.value)
for img_dir in _IMG_DIR.value:
img_lists.extend(tf.io.gfile.glob(os.path.join(img_dir, '*')))
logging.info('Total number of input images: %d', len(img_lists))
model = load_model()
vis_dir = _VIS_DIR.value
output = {'annotations': []}
for img_file in tqdm.tqdm(img_lists):
output['annotations'].append({
'image_id': img_file.split('/')[-1].split('.')[0],
'paragraphs': inference(img_file, model),
})
if vis_dir:
key = output['annotations'][-1]['image_id']
paragraphs = output['annotations'][-1]['paragraphs']
img = cv2.cvtColor(cv2.imread(img_file), cv2.COLOR_BGR2RGB)
word_bnds = []
line_bnds = []
para_bnds = []
for paragraph in paragraphs:
paragraph_points_list = []
for line in paragraph['lines']:
line_points_list = []
for word in line['words']:
word_bnds.append(
np.array(word['vertices'], np.int32).reshape((-1, 1, 2)))
line_points_list.extend(word['vertices'])
paragraph_points_list.extend(line_points_list)
line_points = np.array(line_points_list, np.int32) # (N,2)
left = int(np.min(line_points[:, 0]))
top = int(np.min(line_points[:, 1]))
right = int(np.max(line_points[:, 0]))
bottom = int(np.max(line_points[:, 1]))
line_bnds.append(
np.array([[[left, top]], [[right, top]], [[right, bottom]],
[[left, bottom]]], np.int32))
para_points = np.array(paragraph_points_list, np.int32) # (N,2)
left = int(np.min(para_points[:, 0]))
top = int(np.min(para_points[:, 1]))
right = int(np.max(para_points[:, 0]))
bottom = int(np.max(para_points[:, 1]))
para_bnds.append(
np.array([[[left, top]], [[right, top]], [[right, bottom]],
[[left, bottom]]], np.int32))
for name, bnds in zip(['paragraph', 'line', 'word'],
[para_bnds, line_bnds, word_bnds]):
vis = cv2.polylines(img, bnds, True, (0, 0, 255), 2)
cv2.imwrite(os.path.join(vis_dir, f'{key}-{name}.jpg'),
cv2.cvtColor(vis, cv2.COLOR_RGB2BGR))
with tf.io.gfile.GFile(_OUTPUT_PATH.value, mode='w') as f:
f.write(json.dumps(output, ensure_ascii=False, indent=2))
if __name__ == '__main__':
flags.mark_flags_as_required(['gin_file', 'ckpt_path', 'output_path'])
app.run(main)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Import all models.
All model files are imported here so that they can be referenced in Gin. Also,
importing here avoids making ocr_task.py too messy.
"""
# pylint: disable=unused-import
from official.projects.unified_detector.modeling import universal_detector
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task definition for ocr."""
from typing import Callable, Dict, Optional, Sequence, Tuple, Union
import gin
import tensorflow as tf
from official.core import base_task
from official.core import config_definitions as cfg
from official.core import task_factory
from official.projects.unified_detector.configs import ocr_config
from official.projects.unified_detector.data_loaders import input_reader
from official.projects.unified_detector.tasks import all_models # pylint: disable=unused-import
from official.projects.unified_detector.utils import typing
NestedTensorDict = typing.NestedTensorDict
ModelType = Union[tf.keras.layers.Layer, tf.keras.Model]
@task_factory.register_task_cls(ocr_config.OcrTaskConfig)
@gin.configurable
class OcrTask(base_task.Task):
"""Defining the OCR training task."""
_loss_items = []
def __init__(self,
params: cfg.TaskConfig,
logging_dir: Optional[str] = None,
name: Optional[str] = None,
model_fn: Callable[..., ModelType] = gin.REQUIRED):
super().__init__(params, logging_dir, name)
self._model_fn = model_fn
def build_model(self) -> ModelType:
"""Build and return the model, record the loss items as well."""
model = self._model_fn()
self._loss_items.extend(model.loss_items)
return model
def build_inputs(
self,
params: cfg.DataConfig,
input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
"""Build the tf.data.Dataset instance."""
return input_reader.InputFn(is_training=params.is_training)({},
input_context)
def build_metrics(self,
training: bool = True) -> Sequence[tf.keras.metrics.Metric]:
"""Build the metrics (currently, only for loss summaries in TensorBoard)."""
del training
metrics = []
# Add loss items
for name in self._loss_items:
metrics.append(tf.keras.metrics.Mean(name, dtype=tf.float32))
# TODO(longshangbang): add evaluation metrics
return metrics
def train_step(
self,
inputs: Tuple[NestedTensorDict, NestedTensorDict],
model: ModelType,
optimizer: tf.keras.optimizers.Optimizer,
metrics: Optional[Sequence[tf.keras.metrics.Metric]] = None
) -> Dict[str, tf.Tensor]:
features, labels = inputs
input_dict = {"features": features}
if self.task_config.model_call_needs_labels:
input_dict["labels"] = labels
is_mixed_precision = isinstance(optimizer,
tf.keras.mixed_precision.LossScaleOptimizer)
with tf.GradientTape() as tape:
outputs = model(**input_dict, training=True)
loss, loss_dict = model.compute_losses(labels=labels, outputs=outputs)
loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
if is_mixed_precision:
loss = optimizer.get_scaled_loss(loss)
tvars = model.trainable_variables
grads = tape.gradient(loss, tvars)
if is_mixed_precision:
grads = optimizer.get_unscaled_gradients(grads)
optimizer.apply_gradients(list(zip(grads, tvars)))
logs = {"loss": loss}
if metrics:
for m in metrics:
m.update_state(loss_dict[m.name])
return logs
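# Illustrative sketch (not part of the original file): `model_fn` is injected
# through Gin (it is `gin.REQUIRED` above). A hypothetical binding in a .gin
# file could look like:
#   OcrTask.model_fn = @universal_detector.UniversalDetector
# after which `task_factory.get_task(params.task)` builds the task and
# `task.build_model()` instantiates the detector and records its loss names.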
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorFlow Model Garden Vision training driver."""
from absl import app
from absl import flags
import gin
from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
# pylint: disable=unused-import
from official.projects.unified_detector import registry_imports
# pylint: enable=unused-import
FLAGS = flags.FLAGS
def main(_):
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
params = train_utils.parse_configuration(FLAGS)
model_dir = FLAGS.model_dir
if 'train' in FLAGS.mode:
# Pure eval modes do not output yaml files. Otherwise, the continuous eval
# job may race against the train job when writing the same file.
train_utils.serialize_config(params, model_dir)
# Set the mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
# can significantly speed up the model by computing in float16 on GPUs and
# bfloat16 on TPUs. loss_scale takes effect only when the dtype is float16.
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu)
with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir)
train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode=FLAGS.mode,
params=params,
model_dir=model_dir)
train_utils.save_gin_config(FLAGS.mode, model_dir)
if __name__ == '__main__':
tfm_flags.define_flags()
flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
app.run(main)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Typing extension."""
from typing import Dict, Union
import numpy as np
import tensorflow as tf
NpDict = Dict[str, np.ndarray]
FeaturesAndLabelsType = Dict[str, Dict[str, tf.Tensor]]
TensorDict = Dict[Union[str, int], tf.Tensor]
NestedTensorDict = Dict[
Union[str, int],
Union[tf.Tensor,
TensorDict]]
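# Illustrative sketch (not part of the original file): a NestedTensorDict
# allows at most one level of nesting, e.g.
#   example: NestedTensorDict = {
#       "features": tf.zeros((1, 3)),
#       "labels": {"masks": tf.zeros((1, 4, 4))},
#   }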