Commit ac1f5735 authored by Abdullah Rashwan, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 452578619
parent 52902342
# Towards End-to-End Unified Scene Text Detection and Layout Analysis
![unified detection](docs/images/task.png)
[![UnifiedDetector](https://img.shields.io/badge/UnifiedDetector-arxiv.2203.15143-green)](https://arxiv.org/abs/2203.15143)
Official TensorFlow 2 implementation of the paper `Towards End-to-End Unified
Scene Text Detection and Layout Analysis`. If you encounter any issues using the
code, you are welcome to submit them to the Issues tab or email us directly at
`hiertext@google.com`.
## Installation
### Set up TensorFlow Models
```bash
# (Optional) Create and enter a virtual environment
pip3 install --user virtualenv
virtualenv -p python3 unified_detector
source ./unified_detector/bin/activate
# First clone the TensorFlow Models project:
git clone https://github.com/tensorflow/models.git
# Install the requirements of TensorFlow Models and this repo:
cd models
pip3 install -r official/requirements.txt
pip3 install -r official/projects/unified_detector/requirements.txt
# Compile the protos
# If `protoc` is not installed, please follow: https://grpc.io/docs/protoc-installation/
export PYTHONPATH=${PYTHONPATH}:${PWD}/research/
cd research/object_detection/
protoc protos/string_int_label_map.proto --python_out=.
```
### Set up Deeplab2
```bash
# Clone Deeplab2 anywhere you like
cd <somewhere>
git clone https://github.com/google-research/deeplab2.git
# Compile the protos
protoc deeplab2/*.proto --python_out=.
# Add to PYTHONPATH the directory where deeplab2 sits.
export PYTHONPATH=${PYTHONPATH}:${PWD}
```
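To sanity-check the setup, the imports below should succeed when run from a Python shell with the two `PYTHONPATH` entries above exported. This is a minimal, hedged check; the module names are assumptions derived from the `protoc` commands above, so adjust them if your layout differs.
```python
# Hedged sanity check for the environment set up above.
from object_detection.protos import string_int_label_map_pb2  # from models/research
from deeplab2 import config_pb2  # from the deeplab2 clone

print(string_int_label_map_pb2.StringIntLabelMap)
print(config_pb2.ExperimentOptions)
```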
## Running the model on images with the provided checkpoint
### Download the checkpoint
Model | Input Resolution | #object query | line PQ (val) | paragraph PQ (val) | line PQ (test) | paragraph PQ (test)
---------------------------------------------------------------------------------------------------------------------------------- | ---------------- | ------------- | ------------- | ------------------ | -------------- | -------------------
Unified-Detector-Line ([ckpt](https://storage.cloud.google.com/tf_model_garden/vision/unified_detector/unified_detector_ckpt.tgz)) | 1024 | 384 | 61.04 | 52.84 | 62.20 | 53.52
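If you prefer to restore the checkpoint programmatically rather than through `run_inference`, the following is a minimal sketch that mirrors `load_model()` in `run_inference.py`; the paths are placeholders.
```python
import gin
import tensorflow as tf

from official.projects.unified_detector import external_configurables  # pylint: disable=unused-import
from official.projects.unified_detector.modeling import universal_detector

# Placeholder paths: the model gin file from this repo and the extracted
# checkpoint from the table above.
gin.parse_config_file(
    'official/projects/unified_detector/configs/gin_files/unified_detector_model.gin')
model = universal_detector.UniversalDetector()
ckpt = tf.train.Checkpoint(model=model)
ckpt.restore('<path-of-the-ckpt>').expect_partial()
```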
### Demo on single images
```bash
# Run from `models/`
python3 -m official.projects.unified_detector.run_inference \
--gin_file=official/projects/unified_detector/configs/gin_files/unified_detector_model.gin \
--ckpt_path=<path-of-the-ckpt> \
--img_file=<some-image> \
--output_path=<some-directory>/demo.jsonl \
--vis_dir=<some-directory>
```
The output will be stored in jsonl format, following the same hierarchical
structure required by the evaluation script of the HierText dataset.
Visualizations of the word/line/paragraph boundaries will also be written to
`vis_dir`. Note that the unified detector produces line-level masks and an
affinity matrix for grouping lines into paragraphs. For visualization purposes,
we split each line mask into connected components and visualize these pixel
groups as `words`, although they are not necessarily at word granularity. Lines
and paragraphs are visualized as axis-aligned bounding boxes around these
`words`.
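For reference, below is a minimal sketch of the structure written to the output file. The coordinates are hypothetical, and the `text` fields stay empty because the unified detector does not perform recognition.
```python
# A hypothetical, truncated example of the output structure.
{
    'annotations': [{
        'image_id': 'demo',
        'paragraphs': [{
            'lines': [{
                'text': '',
                'words': [{
                    'text': '',
                    'vertices': [[103, 52], [341, 50], [343, 88], [105, 90]],
                }],
            }],
        }],
    }]
}
```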
## Inference and Evaluation on the HierText dataset
### Download the HierText dataset
Clone the [HierText repo](https://github.com/google-research-datasets/hiertext)
and download the dataset. The `requirements.txt` in this folder already covers
the requirements of the HierText repo, so there is no need to create another
virtual environment.
### Inference and eval
The following command runs the model on the validation set and computes the
scores. Note that the test set annotations have not been released yet, so only
the validation set is used here for demonstration purposes.
#### Inference
```bash
# Run from `models/`
python3 -m official.projects.unified_detector.run_inference \
--gin_file=official/projects/unified_detector/configs/gin_files/unified_detector_model.gin \
--ckpt_path=<path-of-the-ckpt> \
--img_dir=<the-directory-containing-validation-images> \
--output_path=<some-directory>/validation_output.jsonl
```
#### Evaluation
```bash
# Run from `hiertext/`
python3 eval.py \
--gt=gt/validation.jsonl \
--result=<some-directory>/validation_output.jsonl \
--output=./validation-score.txt \
--mask_stride=1 \
--eval_lines \
--eval_paragraphs \
--num_workers=0
```
## Train new models
First, you will need to convert the HierText dataset into TFRecords:
```bash
# Run from `models/official/projects/unified_detector/data_conversion`
CUDA_VISIBLE_DEVICES='' python3 convert.py \
--gt_file=/path/to/gt.jsonl \
--img_dir=/path/to/image \
--out_file=/path/to/tfrecords/file-prefix
```
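To verify the conversion, the snippet below peeks at one generated record and lists its feature keys (e.g. `image/encoded`, `image/additional_channels/encoded`, `image/object/classes`). The glob pattern simply follows the `--out_file` prefix used above.
```python
import tensorflow as tf

# Read a single converted example and print its feature keys.
files = tf.io.gfile.glob('/path/to/tfrecords/file-prefix*')
for record in tf.data.TFRecordDataset(files).take(1):
    example = tf.train.Example.FromString(record.numpy())
    print(sorted(example.features.feature.keys()))
```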
To train the unified detector, run the following script:
```bash
# Run from `models/`
python3 -m official.projects.unified_detector.train \
--mode=train \
--experiment=unified_detector \
--model_dir='<some path>' \
--gin_file='official/projects/unified_detector/configs/gin_files/unified_detector_train.gin' \
--gin_file='official/projects/unified_detector/configs/gin_files/unified_detector_model.gin' \
--gin_params='InputFn.input_paths = ["/path/to/tfrecords/file-prefix*"]'
```
## Citation
Please cite our [paper](https://arxiv.org/pdf/2203.15143.pdf) if you find this
work helpful:
```
@inproceedings{long2022towards,
title={Towards End-to-End Unified Scene Text Detection and Layout Analysis},
author={Long, Shangbang and Qin, Siyang and Panteleev, Dmitry and Bissacco, Alessandro and Fujii, Yasuhisa and Raptis, Michalis},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2022}
}
```
# Defining the unified detector models.
# Model
## Backbone
num_slots = 384
SyncBatchNormalization.momentum = 0.95
get_max_deep_lab_backbone.num_slots = %num_slots
## Decoder
intermediate_filters = 256
num_entity_class = 3 # C + 1 (bkg) + 1 (void)
_get_decoder_head.atrous_rates = (6, 12, 18)
_get_decoder_head.pixel_space_dim = 128
_get_decoder_head.pixel_space_intermediate = %intermediate_filters
_get_decoder_head.num_classes = %num_entity_class
_get_decoder_head.aux_sem_intermediate = %intermediate_filters
_get_decoder_head.low_level = [
{'feature_key': 'res3', 'channels_project': 64,},
{'feature_key': 'res2', 'channels_project': 32,},]
_get_decoder_head.norm_fn = @SyncBatchNormalization
_get_embed_head.norm_fn = @LayerNorm
# Loss
# pq loss
alpha = 0.75
tau = 0.3
_entity_mask_loss.alpha = %alpha
_instance_discrimination_loss.tau = %tau
_paragraph_grouping_loss.tau = %tau
_paragraph_grouping_loss.loss_mode = 'balanced'
# Other Model setting
UniversalDetector.mask_threshold = 0.4
UniversalDetector.class_threshold = 0.5
UniversalDetector.filter_area = 32
universal_detection_loss_weights.loss_segmentation_word = 1e0
universal_detection_loss_weights.loss_inst_dist = 1e0
universal_detection_loss_weights.loss_mask_id = 1e-4
universal_detection_loss_weights.loss_pq = 3e0
universal_detection_loss_weights.loss_para = 1e0
# Defining the input pipeline of unified detector.
# ===== ===== Model ===== =====
# Internal import 2.
OcrTask.model_fn = @UniversalDetector
# ===== ===== Data pipeline ===== =====
InputFn.parser_fn = @UniDetectorParserFn
InputFn.dataset_type = 'tfrecord'
InputFn.batch_size = 256
# Internal import 3.
UniDetectorParserFn.output_dimension = 1024
# Simple data augmentation for now.
UniDetectorParserFn.rot90_probability = 0.0
UniDetectorParserFn.use_color_distortion = True
UniDetectorParserFn.crop_min_scale = 0.5
UniDetectorParserFn.crop_max_scale = 1.5
UniDetectorParserFn.crop_min_aspect = 0.8
UniDetectorParserFn.crop_max_aspect = 1.25
UniDetectorParserFn.max_num_instance = 384
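# Note: these bindings are consumed at training time via
# gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params) in
# train.py. `InputFn.input_paths` is gin.REQUIRED and must be supplied on the
# command line, e.g.
#   --gin_params='InputFn.input_paths = ["/path/to/tfrecords/file-prefix*"]'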
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""OCR tasks and models configurations."""
import dataclasses
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import optimization
@dataclasses.dataclass
class OcrTaskConfig(cfg.TaskConfig):
train_data: cfg.DataConfig = cfg.DataConfig()
model_call_needs_labels: bool = False
@exp_factory.register_config_factory('unified_detector')
def unified_detector() -> cfg.ExperimentConfig:
"""Configurations for trainer of unified detector."""
total_train_steps = 100000
summary_interval = steps_per_loop = 200
checkpoint_interval = 2000
warmup_steps = 1000
config = cfg.ExperimentConfig(
# Input pipeline and model are configured through Gin.
task=OcrTaskConfig(train_data=cfg.DataConfig(is_training=True)),
trainer=cfg.TrainerConfig(
train_steps=total_train_steps,
steps_per_loop=steps_per_loop,
summary_interval=summary_interval,
checkpoint_interval=checkpoint_interval,
max_to_keep=1,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adamw',
'adamw': {
'weight_decay_rate': 0.05,
'include_in_weight_decay': [
'^((?!depthwise).)*(kernel|weights):0$',
],
'exclude_from_weight_decay': [
'(^((?!kernel).)*:0)|(depthwise_kernel)',
],
'gradient_clip_norm': 10.,
},
},
'learning_rate': {
'type': 'cosine',
'cosine': {
'initial_learning_rate': 1e-3,
'decay_steps': total_train_steps - warmup_steps,
'alpha': 1e-2,
'offset': warmup_steps,
},
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_learning_rate': 1e-5,
'warmup_steps': warmup_steps,
}
},
}),
),
)
return config
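# A minimal usage sketch (assumes the registry imports in registry_imports.py
# have been executed):
#   from official.core import exp_factory
#   config = exp_factory.get_exp_config('unified_detector')
#   assert config.trainer.train_steps == 100000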
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Script to convert HierText to TFExamples.
This script is only intended to run locally.
python3 data_preprocess/convert.py \
--gt_file=/path/to/gt.jsonl \
--img_dir=/path/to/image \
--out_file=/path/to/tfrecords/file-prefix
"""
import json
import os
import random
from absl import app
from absl import flags
import tensorflow as tf
import tqdm
import utils
_GT_FILE = flags.DEFINE_string('gt_file', None, 'Path to the GT file')
_IMG_DIR = flags.DEFINE_string('img_dir', None, 'Path to the image folder.')
_OUT_FILE = flags.DEFINE_string('out_file', None, 'Path for the tfrecords.')
_NUM_SHARD = flags.DEFINE_integer(
'num_shard', 100, 'The number of shards of tfrecords.')
def main(unused_argv) -> None:
annotations = json.load(open(_GT_FILE.value))['annotations']
random.shuffle(annotations)
n_sample = len(annotations)
n_shards = _NUM_SHARD.value
n_sample_per_shard = (n_sample - 1) // n_shards + 1
for shard in tqdm.tqdm(range(n_shards)):
output_path = f'{_OUT_FILE.value}-{shard:05}-{n_shards:05}.tfrecords'
annotation_subset = annotations[
shard * n_sample_per_shard : (shard + 1) * n_sample_per_shard]
with tf.io.TFRecordWriter(output_path) as file_writer:
for annotation in annotation_subset:
img_file_path = os.path.join(_IMG_DIR.value,
f"{annotation['image_id']}.jpg")
tfexample = utils.convert_to_tfe(img_file_path, annotation)
file_writer.write(tfexample)
if __name__ == '__main__':
flags.mark_flags_as_required(['gt_file', 'img_dir', 'out_file'])
app.run(main)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities to convert data to TFExamples and store in TFRecords."""
from typing import Any, Dict, List, Tuple, Union
import cv2
import numpy as np
import tensorflow as tf
def encode_image(
image_tensor: np.ndarray,
encoding_type: str = 'png') -> Union[np.ndarray, tf.Tensor]:
"""Encode image tensor into byte string."""
if encoding_type == 'jpg':
image_encoded = tf.image.encode_jpeg(tf.constant(image_tensor))
elif encoding_type == 'png':
image_encoded = tf.image.encode_png(tf.constant(image_tensor))
else:
raise ValueError('Invalid encoding type.')
if tf.executing_eagerly():
image_encoded = image_encoded.numpy()
else:
image_encoded = image_encoded.eval()
return image_encoded
def int64_feature(value: Union[int, List[int]]) -> tf.train.Feature:
if not isinstance(value, list):
value = [value]
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def float_feature(value: Union[float, List[float]]) -> tf.train.Feature:
if not isinstance(value, list):
value = [value]
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
def bytes_feature(value: Union[Union[bytes, str], List[Union[bytes, str]]]
) -> tf.train.Feature:
if not isinstance(value, list):
value = [value]
for i in range(len(value)):
if not isinstance(value[i], bytes):
value[i] = value[i].encode('utf-8')
return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
def annotation_to_entities(annotation: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Flatten the annotation dict to a list of 'entities'."""
entities = []
for paragraph in annotation['paragraphs']:
paragraph_id = len(entities)
paragraph['type'] = 3 # 3 for paragraph
paragraph['parent_id'] = -1
entities.append(paragraph)
for line in paragraph['lines']:
line_id = len(entities)
line['type'] = 2 # 2 for line
line['parent_id'] = paragraph_id
entities.append(line)
for word in line['words']:
word['type'] = 1 # 1 for word
word['parent_id'] = line_id
entities.append(word)
return entities
def draw_entity_mask(
entities: List[Dict[str, Any]],
image_shape: Tuple[int, int, int]) -> np.ndarray:
"""Draw entity id mask.
Args:
entities: A list of entity objects. Should be output from
`annotation_to_entities`.
image_shape: The shape of the input image.
Returns:
A (H, W, 3) entity id mask of the same height/width as the image. Each pixel
(i, j, :) encodes the entity id of that pixel. Only word entities are
rendered; 0 marks non-text pixels and word entity ids start from 1.
"""
instance_mask = np.zeros(image_shape, dtype=np.uint8)
for i, entity in enumerate(entities):
# only draw word masks
if entity['type'] != 1:
continue
vertices = np.array(entity['vertices'])
# the pixel value is actually 1 + position in entities
entity_id = i + 1
if entity_id >= 65536:
# As entity_id is encoded in the last two channels, it should be less than
# 256**2=65536.
raise ValueError(
(f'Entity ID overflow: {entity_id}. Currently only entity_id<65536 '
'are supported.'))
# use the last two channels to encode the entity id.
color = [0, entity_id // 256, entity_id % 256]
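# (A pixel's entity id can be recovered as mask[..., 1] * 256 + mask[..., 2].)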
instance_mask = cv2.fillPoly(instance_mask,
[np.round(vertices).astype('int32')], color)
return instance_mask
def convert_to_tfe(img_file_name: str,
annotation: Dict[str, Any]) -> tf.train.Example:
"""Convert the annotation dict into a TFExample."""
img = cv2.imread(img_file_name)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
h, w, c = img.shape
encoded_img = encode_image(img)
entities = annotation_to_entities(annotation)
masks = draw_entity_mask(entities, img.shape)
encoded_mask = encode_image(masks)
# encode attributes
parent = []
classes = []
content_type = []
text = []
vertices = []
for entity in entities:
parent.append(entity['parent_id'])
classes.append(entity['type'])
# 0 for annotated; 8 for not annotated
content_type.append((0 if entity['legible'] else 8))
text.append(entity.get('text', ''))
v = np.array(entity['vertices'])
vertices.append(','.join(str(float(n)) for n in v.reshape(-1)))
example = tf.train.Example(
features=tf.train.Features(
feature={
# input images
'image/encoded': bytes_feature(encoded_img),
# image format
'image/format': bytes_feature('png'),
# image width
'image/width': int64_feature([w]),
# image height
'image/height': int64_feature([h]),
# image channels
'image/channels': int64_feature([c]),
# image key
'image/source_id': bytes_feature(annotation['image_id']),
# HxWx3 tensors: channel 2-3 encodes the id of the word entity.
'image/additional_channels/encoded': bytes_feature(encoded_mask),
# format of the additional channels
'image/additional_channels/format': bytes_feature('png'),
'image/object/parent': int64_feature(parent),
# word / line / paragraph / symbol / ...
'image/object/classes': int64_feature(classes),
# text / handwritten / not-annotated / ...
'image/object/content_type': int64_feature(content_type),
# string text transcription
'image/object/text': bytes_feature(text),
# comma separated coordinates, (x,y) * n
'image/object/vertices': bytes_feature(vertices),
})).SerializeToString()
return example
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Input data reader.
Creates a tf.data.Dataset object from multiple input sources (SSTables or
TFRecords) and uses a provided data parser function to decode the serialized
tf.Example protos and optionally run data augmentation.
"""
import os
from typing import Any, Callable, List, Optional, Sequence, Union
import gin
from six.moves import map
import tensorflow as tf
from official.common import dataset_fn
from research.object_detection.utils import label_map_util
from official.core import config_definitions as cfg
from official.projects.unified_detector.data_loaders import universal_detection_parser # pylint: disable=unused-import
FuncType = Callable[..., Any]
@gin.configurable(denylist=['is_training'])
class InputFn(object):
"""Input data reader class.
Creates a tf.data.Dataset object from multiple datasets (optionally performs
weighted sampling between different datasets), parses the tf.Example message
using `parser_fn`. The datasets can either be stored in SSTable or TfRecord.
"""
def __init__(self,
is_training: bool,
batch_size: Optional[int] = None,
data_root: str = '',
input_paths: List[str] = gin.REQUIRED,
dataset_type: str = 'tfrecord',
use_sampling: bool = False,
sampling_weights: Optional[Sequence[Union[int, float]]] = None,
cycle_length: Optional[int] = 64,
shuffle_buffer_size: Optional[int] = 512,
parser_fn: Optional[FuncType] = None,
parser_num_parallel_calls: Optional[int] = 64,
max_intra_op_parallelism: Optional[int] = None,
label_map_proto_path: Optional[str] = None,
input_filter_fns: Optional[List[FuncType]] = None,
input_training_filter_fns: Optional[Sequence[FuncType]] = None,
dense_to_ragged_batch: bool = False,
data_validator_fn: Optional[Callable[[Sequence[str]],
None]] = None):
"""Input reader constructor.
Args:
is_training: Boolean indicating TRAIN or EVAL.
batch_size: Input data batch size. Ignored if batch size is passed through
params. In that case, this can be None.
data_root: All the relative input paths are based on this location.
input_paths: Input file patterns.
dataset_type: Can be 'sstable' or 'tfrecord'.
use_sampling: Whether to perform weighted sampling between different
datasets.
sampling_weights: Unnormalized sampling weights. The length should be
equal to `input_paths`.
cycle_length: The number of input Datasets to interleave from in parallel.
If set to None tf.data experimental autotuning is used.
shuffle_buffer_size: The random shuffle buffer size.
parser_fn: The function to run decoding and data augmentation. The
function takes `is_training` as an input, which is passed from here.
parser_num_parallel_calls: The number of parallel calls for `parser_fn`.
The number of CPU cores is the suggested value. If set to None tf.data
experimental autotuning is used.
max_intra_op_parallelism: if set limits the max intra op parallelism of
functions run on slices of the input.
label_map_proto_path: Path to a StringIntLabelMap which will be used to
decode the input data.
input_filter_fns: A list of functions on the dataset points which return
true for valid data.
input_training_filter_fns: A list of functions on the dataset points which
return true for valid data; applied only for training.
dense_to_ragged_batch: Whether to use ragged batching for MPNN format.
data_validator_fn: If not None, used to validate the data specified by
input_paths.
Raises:
ValueError for invalid input_paths.
"""
self._is_training = is_training
if data_root:
# If an input path is absolute this does not change it.
input_paths = [os.path.join(data_root, value) for value in input_paths]
self._input_paths = input_paths
# Disables datasets sampling during eval.
self._batch_size = batch_size
if is_training:
self._use_sampling = use_sampling
else:
self._use_sampling = False
self._sampling_weights = sampling_weights
self._cycle_length = (cycle_length if cycle_length else tf.data.AUTOTUNE)
self._shuffle_buffer_size = shuffle_buffer_size
self._parser_num_parallel_calls = (
parser_num_parallel_calls
if parser_num_parallel_calls else tf.data.AUTOTUNE)
self._max_intra_op_parallelism = max_intra_op_parallelism
self._label_map_proto_path = label_map_proto_path
if label_map_proto_path:
name_to_id = label_map_util.get_label_map_dict(label_map_proto_path)
self._lookup_str_keys = list(name_to_id.keys())
self._lookup_int_values = list(name_to_id.values())
self._parser_fn = parser_fn
self._input_filter_fns = input_filter_fns or []
if is_training and input_training_filter_fns:
self._input_filter_fns.extend(input_training_filter_fns)
self._dataset_type = dataset_type
self._dense_to_ragged_batch = dense_to_ragged_batch
if data_validator_fn is not None:
data_validator_fn(self._input_paths)
@property
def batch_size(self):
return self._batch_size
def __call__(
self,
params: cfg.DataConfig,
input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
"""Read and parse input datasets, return a tf.data.Dataset object."""
# TPUEstimator passes the batch size through params.
if params is not None and 'batch_size' in params:
batch_size = params['batch_size']
else:
batch_size = self._batch_size
per_replica_batch_size = input_context.get_per_replica_batch_size(
batch_size) if input_context else batch_size
with tf.name_scope('input_reader'):
dataset = self._build_dataset_from_records()
dataset_parser_fn = self._build_dataset_parser_fn()
dataset = dataset.map(
dataset_parser_fn, num_parallel_calls=self._parser_num_parallel_calls)
for filter_fn in self._input_filter_fns:
dataset = dataset.filter(filter_fn)
if self._dense_to_ragged_batch:
dataset = dataset.apply(
tf.data.experimental.dense_to_ragged_batch(
batch_size=per_replica_batch_size, drop_remainder=True))
else:
dataset = dataset.batch(per_replica_batch_size, drop_remainder=True)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
return dataset
def _fetch_dataset(self, filename: str) -> tf.data.Dataset:
"""Fetch dataset depending on type.
Args:
filename: Location of dataset.
Returns:
Tf Dataset.
"""
data_cls = dataset_fn.pick_dataset_fn(self._dataset_type)
data = data_cls([filename])
return data
def _build_dataset_parser_fn(self) -> Callable[..., tf.Tensor]:
"""Depending on label_map and storage type, build a parser_fn."""
# Parse the fetched records to input tensors for model function.
if self._label_map_proto_path:
lookup_initializer = tf.lookup.KeyValueTensorInitializer(
keys=tf.constant(self._lookup_str_keys, dtype=tf.string),
values=tf.constant(self._lookup_int_values, dtype=tf.int32))
name_to_id_table = tf.lookup.StaticHashTable(
initializer=lookup_initializer, default_value=0)
parser_fn = self._parser_fn(
is_training=self._is_training, label_lookup_table=name_to_id_table)
else:
parser_fn = self._parser_fn(is_training=self._is_training)
return parser_fn
def _build_dataset_from_records(self) -> tf.data.Dataset:
"""Build a tf.data.Dataset object from input SSTables.
If the input data comes from multiple SSTables, the user-defined sampling
weights are used to perform sampling. For example, if the sampling weights are
[1., 2.], the second dataset will be sampled twice as often as the first
one.
Returns:
Dataset built from SSTables.
Raises:
ValueError for inability to find SSTable files.
"""
all_file_patterns = []
if self._use_sampling:
for file_pattern in self._input_paths:
all_file_patterns.append([file_pattern])
# Normalize sampling probabilities.
total_weight = sum(self._sampling_weights)
sampling_probabilities = [
float(w) / total_weight for w in self._sampling_weights
]
else:
all_file_patterns.append(self._input_paths)
datasets = []
for file_pattern in all_file_patterns:
filenames = sum(list(map(tf.io.gfile.glob, file_pattern)), [])
if not filenames:
raise ValueError(
f'Error trying to read input files for file pattern {file_pattern}')
# Create a dataset of filenames and shuffle the files. In each epoch,
# the file order is shuffled again. This may help if
# per_host_input_for_training = false on TPU.
dataset = tf.data.Dataset.list_files(
file_pattern, shuffle=self._is_training)
if self._is_training:
dataset = dataset.repeat()
if self._max_intra_op_parallelism:
# Disable intra-op parallelism to optimize for throughput instead of
# latency.
options = tf.data.Options()
options.experimental_threading.max_intra_op_parallelism = 1
dataset = dataset.with_options(options)
dataset = dataset.interleave(
self._fetch_dataset,
cycle_length=self._cycle_length,
num_parallel_calls=self._cycle_length,
deterministic=(not self._is_training))
if self._is_training:
dataset = dataset.shuffle(self._shuffle_buffer_size)
datasets.append(dataset)
if self._use_sampling:
assert len(datasets) == len(sampling_probabilities)
dataset = tf.data.experimental.sample_from_datasets(
datasets, sampling_probabilities)
else:
dataset = datasets[0]
return dataset
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tensorflow Example proto decoder for GOCR."""
from typing import List, Optional, Sequence, Tuple, Union
import tensorflow as tf
from official.projects.unified_detector.utils.typing import TensorDict
from official.vision.dataloaders import decoder
class TfExampleDecoder(decoder.Decoder):
"""Tensorflow Example proto decoder."""
def __init__(self,
use_instance_mask: bool = False,
additional_class_names: Optional[Sequence[str]] = None,
additional_regression_names: Optional[Sequence[str]] = None,
num_additional_channels: int = 0):
"""Constructor.
keys_to_features is a dictionary mapping the names of the tf.Example
fields to tf features, possibly with defaults.
Uses fixed length for scalars and variable length for vectors.
Args:
use_instance_mask: if False, prevents decoding of the instance mask, which
can take a lot of resources.
additional_class_names: If not None, a list of additional class names. For
each additional class name n, the feature named image/object/${n} is
expected to be an int vector of length one, and is mapped to the tensor
dict key groundtruth_${n}.
additional_regression_names: If not None, a list of additional regression
output names. For each additional regression name n, the feature named
image/object/${n} is expected to be a float vector, and is mapped to the
tensor dict key groundtruth_${n}.
num_additional_channels: The number of additional channels of information
present in the tf.Example proto.
"""
self._num_additional_channels = num_additional_channels
self._use_instance_mask = use_instance_mask
self.keys_to_features = {}
# Map names in the final tensor dict (output of `self.decode()`) to names in
# tf examples, e.g. 'groundtruth_text' -> 'image/object/text'
self.name_to_key = {}
if use_instance_mask:
self.keys_to_features.update({
'image/object/mask': tf.io.VarLenFeature(tf.string),
})
# Now we have lists of standard types.
# To add new features, just add entries here.
# The tuple elements are (example name, tensor name, default value).
# If the items_to_handlers part is already set up use None for
# the tensor name.
# There are other tensor names listed as None which we probably
# want to discuss and specify.
scalar_strings = [
('image/encoded', None, ''),
('image/format', None, 'jpg'),
('image/additional_channels/encoded', None, ''),
('image/additional_channels/format', None, 'png'),
('image/label_type', 'label_type', ''),
('image/key', 'key', ''),
('image/source_id', 'source_id', ''),
]
vector_strings = [
('image/attributes', None, ''),
('image/object/text', 'groundtruth_text', ''),
('image/object/encoded_text', 'groundtruth_encoded_text', ''),
('image/object/vertices', 'groundtruth_vertices', ''),
('image/object/object_type', None, ''),
('image/object/language', 'language', ''),
('image/object/reorderer_type', None, ''),
('image/label_map_path', 'label_map_path', '')
]
scalar_ints = [
('image/height', None, 1),
('image/width', None, 1),
('image/channels', None, 3),
]
vector_ints = [
('image/object/classes', 'groundtruth_classes', 0),
('image/object/frame_id', 'frame_id', 0),
('image/object/track_id', 'track_id', 0),
('image/object/content_type', 'groundtruth_content_type', 0),
]
if additional_class_names:
vector_ints += [('image/object/%s' % name, 'groundtruth_%s' % name, 0)
for name in additional_class_names]
# This one is not yet needed:
# scalar_floats = [
# ]
vector_floats = [
('image/object/weight', 'groundtruth_weight', 0),
('image/object/rbox_tl_x', None, 0),
('image/object/rbox_tl_y', None, 0),
('image/object/rbox_width', None, 0),
('image/object/rbox_height', None, 0),
('image/object/rbox_angle', None, 0),
('image/object/bbox/xmin', None, 0),
('image/object/bbox/xmax', None, 0),
('image/object/bbox/ymin', None, 0),
('image/object/bbox/ymax', None, 0),
]
if additional_regression_names:
vector_floats += [('image/object/%s' % name, 'groundtruth_%s' % name, 0)
for name in additional_regression_names]
self._init_scalar_features(scalar_strings, tf.string)
self._init_vector_features(vector_strings, tf.string)
self._init_scalar_features(scalar_ints, tf.int64)
self._init_vector_features(vector_ints, tf.int64)
self._init_vector_features(vector_floats, tf.float32)
def _init_scalar_features(
self,
feature_list: List[Tuple[str, Optional[str], Union[str, int, float]]],
ftype: tf.dtypes.DType) -> None:
for entry in feature_list:
self.keys_to_features[entry[0]] = tf.io.FixedLenFeature(
(), ftype, default_value=entry[2])
if entry[1] is not None:
self.name_to_key[entry[1]] = entry[0]
def _init_vector_features(
self,
feature_list: List[Tuple[str, Optional[str], Union[str, int, float]]],
ftype: tf.dtypes.DType) -> None:
for entry in feature_list:
self.keys_to_features[entry[0]] = tf.io.VarLenFeature(ftype)
if entry[1] is not None:
self.name_to_key[entry[1]] = entry[0]
def _decode_png_instance_masks(self, keys_to_tensors: TensorDict) -> tf.Tensor:
"""Decode PNG instance segmentation masks and stack into dense tensor.
The instance segmentation masks are reshaped to [num_instances, height,
width].
Args:
keys_to_tensors: A dictionary from keys to tensors.
Returns:
A 3-D float tensor of shape [num_instances, height, width] with values
in {0, 1}.
"""
def decode_png_mask(image_buffer):
image = tf.squeeze(
tf.image.decode_image(image_buffer, channels=1), axis=2)
image.set_shape([None, None])
image = tf.cast(tf.greater(image, 0), tf.float32)
return image
png_masks = keys_to_tensors['image/object/mask']
height = keys_to_tensors['image/height']
width = keys_to_tensors['image/width']
if isinstance(png_masks, tf.SparseTensor):
png_masks = tf.sparse.to_dense(png_masks, default_value='')
return tf.cond(
tf.greater(tf.size(png_masks), 0),
lambda: tf.map_fn(decode_png_mask, png_masks, dtype=tf.float32),
lambda: tf.zeros(tf.cast(tf.stack([0, height, width]), tf.int32)))
def _decode_image(self,
parsed_tensors: TensorDict,
channel: int = 3) -> TensorDict:
"""Decodes the image and set its shape (H, W are dynamic; C is fixed)."""
image = tf.io.decode_image(parsed_tensors['image/encoded'],
channels=channel)
image.set_shape([None, None, channel])
return {'image': image}
def _decode_additional_channels(self,
parsed_tensors: TensorDict,
channel: int = 3) -> TensorDict:
"""Decodes the additional channels and set its static shape."""
channels = tf.io.decode_image(
parsed_tensors['image/additional_channels/encoded'], channels=channel)
channels.set_shape([None, None, channel])
return {'additional_channels': channels}
def _decode_boxes(self, parsed_tensors: TensorDict) -> TensorDict:
"""Concat box coordinates in the format of [ymin, xmin, ymax, xmax]."""
xmin = parsed_tensors['image/object/bbox/xmin']
xmax = parsed_tensors['image/object/bbox/xmax']
ymin = parsed_tensors['image/object/bbox/ymin']
ymax = parsed_tensors['image/object/bbox/ymax']
return {
'groundtruth_aligned_boxes': tf.stack([ymin, xmin, ymax, xmax], axis=-1)
}
def _decode_rboxes(self, parsed_tensors: TensorDict) -> TensorDict:
"""Concat rbox coordinates: [left, top, box_width, box_height, angle]."""
top_left_x = parsed_tensors['image/object/rbox_tl_x']
top_left_y = parsed_tensors['image/object/rbox_tl_y']
width = parsed_tensors['image/object/rbox_width']
height = parsed_tensors['image/object/rbox_height']
angle = parsed_tensors['image/object/rbox_angle']
return {
'groundtruth_boxes':
tf.stack([top_left_x, top_left_y, width, height, angle], axis=-1)
}
def _decode_masks(self, parsed_tensors: TensorDict) -> TensorDict:
"""Decode a set of PNG masks to the tf.float32 tensors."""
def _decode_png_mask(png_bytes):
mask = tf.squeeze(
tf.io.decode_png(png_bytes, channels=1, dtype=tf.uint8), axis=-1)
mask = tf.cast(mask, dtype=tf.float32)
mask.set_shape([None, None])
return mask
height = parsed_tensors['image/height']
width = parsed_tensors['image/width']
masks = parsed_tensors['image/object/mask']
masks = tf.cond(
pred=tf.greater(tf.size(input=masks), 0),
true_fn=lambda: tf.map_fn(_decode_png_mask, masks, dtype=tf.float32),
false_fn=lambda: tf.zeros([0, height, width], dtype=tf.float32))
return {'groundtruth_instance_masks': masks}
def decode(self, tf_example_string_tensor: tf.string):
"""Decodes serialized tensorflow example and returns a tensor dictionary.
Args:
tf_example_string_tensor: A string tensor holding a serialized tensorflow
example proto.
Returns:
A dictionary containing a subset of the following, depending on the inputs:
image: A uint8 tensor of shape [height, width, 3] containing the image.
source_id: A string tensor containing the image fingerprint.
key: A string tensor containing the unique sha256 hash key.
label_type: Either `full` or `partial`. `full` means all the text is
fully labeled, `partial` otherwise. Currently, this is used by the E2E
model. If an input image is fully labeled, we update the weights of
both the detector and the recognizer. Otherwise, only the recognizer part
of the model is trained.
groundtruth_text: A string tensor list, the original transcriptions.
groundtruth_encoded_text: A string tensor list, the class ids for the
atoms in the text, after applying the reordering algorithm, in string
form. For example "90,71,85,69,86,85,93,90,71,91,1,71,85,93,90,71".
This depends on the class label map provided to the conversion
program. These are 0 based, with -1 for OOV symbols.
groundtruth_classes: An int32 tensor of shape [num_boxes] containing the
class ids. Note these are 1 based; 0 is reserved for the background class.
groundtruth_content_type: An int32 tensor of shape [num_boxes] containing
the content types. Values correspond to PageLayoutEntity::ContentType.
groundtruth_weight: An int32 tensor of shape [num_boxes], either 0 or 1.
If a region has weight 0, it will be ignored when computing the
losses.
groundtruth_boxes: A float tensor of shape [num_boxes, 5] containing the
groundtruth rotated rectangles. Each row is in [left, top, box_width,
box_height, angle] order; absolute coordinates are used.
groundtruth_aligned_boxes: A float tensor of shape [num_boxes, 4]
containing the groundtruth axis-aligned rectangles. Each row is in
[ymin, xmin, ymax, xmax] order. Currently, this is used to store
groundtruth symbol boxes.
groundtruth_vertices: A string tensor list contains encoded normalized
box or polygon coordinates. E.g. `x1,y1,x2,y2,x3,y3,x4,y4`.
groundtruth_instance_masks: A float tensor of shape [num_boxes, height,
width] contains binarized image sized instance segmentation masks.
`1.0` for positive region, `0.0` otherwise. None if not in tfe.
frame_id: An int32 tensor of shape [num_boxes], either `0` or `1`.
`0` means the object comes from the first image, `1` from the second.
track_id: An int32 tensor of shape [num_boxes], whose values indicate
identity across frame indices.
additional_channels: A uint8 tensor of shape [H, W, C] representing some
features.
"""
parsed_tensors = tf.io.parse_single_example(
serialized=tf_example_string_tensor, features=self.keys_to_features)
for k in parsed_tensors:
if isinstance(parsed_tensors[k], tf.SparseTensor):
if parsed_tensors[k].dtype == tf.string:
parsed_tensors[k] = tf.sparse.to_dense(
parsed_tensors[k], default_value='')
else:
parsed_tensors[k] = tf.sparse.to_dense(
parsed_tensors[k], default_value=0)
decoded_tensors = {}
decoded_tensors.update(self._decode_image(parsed_tensors))
decoded_tensors.update(self._decode_rboxes(parsed_tensors))
decoded_tensors.update(self._decode_boxes(parsed_tensors))
if self._use_instance_mask:
decoded_tensors[
'groundtruth_instance_masks'] = self._decode_png_instance_masks(
parsed_tensors)
if self._num_additional_channels:
decoded_tensors.update(self._decode_additional_channels(
parsed_tensors, self._num_additional_channels))
# other attributes:
for key in self.name_to_key:
if key not in decoded_tensors:
decoded_tensors[key] = parsed_tensors[self.name_to_key[key]]
if 'groundtruth_instance_masks' not in decoded_tensors:
decoded_tensors['groundtruth_instance_masks'] = None
return decoded_tensors
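# A hedged usage sketch (variable names are hypothetical):
#   decoder = TfExampleDecoder(num_additional_channels=3)
#   tensors = decoder.decode(serialized_example)
#   image = tensors['image']                      # uint8 [H, W, 3]
#   word_id_mask = tensors['additional_channels'] # uint8 [H, W, 3]; ids in the last two channels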
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrap external code in gin."""
import gin
import gin.tf.external_configurables
import tensorflow as tf
# Tensorflow.
gin.external_configurable(tf.keras.layers.experimental.SyncBatchNormalization)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""All necessary imports for registration."""
# pylint: disable=unused-import
from official.projects.unified_detector import external_configurables
from official.projects.unified_detector.configs import ocr_config
from official.projects.unified_detector.tasks import ocr_task
from official.vision import registry_imports
tf-nightly
gin-config
opencv-python==4.1.2.30
absl-py>=1.0.0
shapely>=1.8.1
apache_beam>=2.37.0
matplotlib>=3.5.1
notebook>=6.4.10
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""A binary to run unified detector."""
import json
import os
from typing import Any, Dict, Sequence, Tuple, Union
from absl import app
from absl import flags
from absl import logging
import cv2
import gin
import numpy as np
import tensorflow as tf
import tqdm
from official.projects.unified_detector import external_configurables # pylint: disable=unused-import
from official.projects.unified_detector.modeling import universal_detector
from official.projects.unified_detector.utils import utilities
# group two lines into a paragraph if affinity score higher than this
_PARA_GROUP_THR = 0.5
# MODEL spec
_GIN_FILE = flags.DEFINE_string(
'gin_file', None, 'Path to the Gin file that defines the model.')
_CKPT_PATH = flags.DEFINE_string(
'ckpt_path', None, 'Path to the checkpoint directory.')
_IMG_SIZE = flags.DEFINE_integer(
'img_size', 1024, 'Size of the image fed to the model.')
# Input & Output
# Note that all images specified by `img_file` and `img_dir` will be processed.
_IMG_FILE = flags.DEFINE_multi_string('img_file', [], 'Paths to the images.')
_IMG_DIR = flags.DEFINE_multi_string(
'img_dir', [], 'Paths to the image directories.')
_OUTPUT_PATH = flags.DEFINE_string('output_path', None, 'Path for the output.')
_VIS_DIR = flags.DEFINE_string(
'vis_dir', None, 'Path for the visualization output.')
def _preprocess(raw_image: np.ndarray) -> Tuple[np.ndarray, float]:
"""Convert a raw image to a properly resized, padded, and normalized ndarray."""
# (1) convert to tf.Tensor and float32.
img_tensor = tf.convert_to_tensor(raw_image, dtype=tf.float32)
# (2) pad to square.
height, width = img_tensor.shape[:2]
maximum_side = tf.maximum(height, width)
height_pad = maximum_side - height
width_pad = maximum_side - width
img_tensor = tf.pad(
img_tensor, [[0, height_pad], [0, width_pad], [0, 0]],
constant_values=127)
ratio = maximum_side / _IMG_SIZE.value
# (3) resize long side to the maximum length.
img_tensor = tf.image.resize(
img_tensor, (_IMG_SIZE.value, _IMG_SIZE.value))
img_tensor = tf.cast(img_tensor, tf.uint8)
# (4) normalize
img_tensor = utilities.normalize_image_to_range(img_tensor)
# (5) Add batch dimension and return as numpy array.
return tf.expand_dims(img_tensor, 0).numpy(), float(ratio)
def load_model() -> tf.keras.layers.Layer:
gin.parse_config_file(_GIN_FILE.value)
model = universal_detector.UniversalDetector()
ckpt = tf.train.Checkpoint(model=model)
ckpt_path = _CKPT_PATH.value
logging.info('Load ckpt from: %s', ckpt_path)
ckpt.restore(ckpt_path).expect_partial()
return model
def inference(img_file: str, model: tf.keras.layers.Layer) -> Dict[str, Any]:
"""Inference step."""
img = cv2.cvtColor(cv2.imread(img_file), cv2.COLOR_BGR2RGB)
img_ndarray, ratio = _preprocess(img)
output_dict = model.serve(img_ndarray)
class_tensor = output_dict['classes'].numpy()
mask_tensor = output_dict['masks'].numpy()
group_tensor = output_dict['groups'].numpy()
indices = np.where(class_tensor[0])[0].tolist() # indices of positive slots.
mask_list = [
mask_tensor[0, :, :, index] for index in indices] # List of mask ndarray.
# Form lines and words
lines = []
line_indices = []
for index, mask in tqdm.tqdm(zip(indices, mask_list)):
line = {
'words': [],
'text': '',
}
contours, _ = cv2.findContours(
(mask > 0.).astype(np.uint8),
cv2.RETR_TREE,
cv2.CHAIN_APPROX_SIMPLE)[-2:]
for contour in contours:
if (isinstance(contour, np.ndarray) and
len(contour.shape) == 3 and
contour.shape[0] > 2 and
contour.shape[1] == 1 and
contour.shape[2] == 2):
cnt_list = (contour[:, 0] * ratio).astype(np.int32).tolist()
line['words'].append({'text': '', 'vertices': cnt_list})
else:
logging.error('Invalid contour: %s, discarded', str(contour))
if line['words']:
lines.append(line)
line_indices.append(index)
# Form paragraphs
line_grouping = utilities.DisjointSet(len(line_indices))
affinity = group_tensor[0][line_indices][:, line_indices]
for i1, i2 in zip(*np.where(affinity > _PARA_GROUP_THR)):
line_grouping.union(i1, i2)
line_groups = line_grouping.to_group()
paragraphs = []
for line_group in line_groups:
paragraph = {'lines': []}
for id_ in line_group:
paragraph['lines'].append(lines[id_])
if paragraph:
paragraphs.append(paragraph)
return paragraphs
def main(argv: Sequence[str]) -> None:
if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.')
# Get list of images
img_lists = []
img_lists.extend(_IMG_FILE.value)
for img_dir in _IMG_DIR.value:
img_lists.extend(tf.io.gfile.glob(os.path.join(img_dir, '*')))
logging.info('Total number of input images: %d', len(img_lists))
model = load_model()
vis_dir = _VIS_DIR.value
output = {'annotations': []}
for img_file in tqdm.tqdm(img_lists):
output['annotations'].append({
'image_id': img_file.split('/')[-1].split('.')[0],
'paragraphs': inference(img_file, model),
})
if vis_dir:
key = output['annotations'][-1]['image_id']
paragraphs = output['annotations'][-1]['paragraphs']
img = cv2.cvtColor(cv2.imread(img_file), cv2.COLOR_BGR2RGB)
word_bnds = []
line_bnds = []
para_bnds = []
for paragraph in paragraphs:
paragraph_points_list = []
for line in paragraph['lines']:
line_points_list = []
for word in line['words']:
word_bnds.append(
np.array(word['vertices'], np.int32).reshape((-1, 1, 2)))
line_points_list.extend(word['vertices'])
paragraph_points_list.extend(line_points_list)
line_points = np.array(line_points_list, np.int32) # (N,2)
left = int(np.min(line_points[:, 0]))
top = int(np.min(line_points[:, 1]))
right = int(np.max(line_points[:, 0]))
bottom = int(np.max(line_points[:, 1]))
line_bnds.append(
np.array([[[left, top]], [[right, top]], [[right, bottom]],
[[left, bottom]]], np.int32))
para_points = np.array(paragraph_points_list, np.int32) # (N,2)
left = int(np.min(para_points[:, 0]))
top = int(np.min(para_points[:, 1]))
right = int(np.max(para_points[:, 0]))
bottom = int(np.max(para_points[:, 1]))
para_bnds.append(
np.array([[[left, top]], [[right, top]], [[right, bottom]],
[[left, bottom]]], np.int32))
for name, bnds in zip(['paragraph', 'line', 'word'],
[para_bnds, line_bnds, word_bnds]):
vis = cv2.polylines(img, bnds, True, (0, 0, 255), 2)
cv2.imwrite(os.path.join(vis_dir, f'{key}-{name}.jpg'),
cv2.cvtColor(vis, cv2.COLOR_RGB2BGR))
with tf.io.gfile.GFile(_OUTPUT_PATH.value, mode='w') as f:
f.write(json.dumps(output, ensure_ascii=False, indent=2))
if __name__ == '__main__':
flags.mark_flags_as_required(['gin_file', 'ckpt_path', 'output_path'])
app.run(main)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Import all models.
All model files are imported here so that they can be referenced in Gin. Also,
importing here avoids making ocr_task.py too messy.
"""
# pylint: disable=unused-import
from official.projects.unified_detector.modeling import universal_detector
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task definition for ocr."""
from typing import Callable, Dict, Optional, Sequence, Tuple, Union
import gin
import tensorflow as tf
from official.core import base_task
from official.core import config_definitions as cfg
from official.core import task_factory
from official.projects.unified_detector.configs import ocr_config
from official.projects.unified_detector.data_loaders import input_reader
from official.projects.unified_detector.tasks import all_models # pylint: disable=unused-import
from official.projects.unified_detector.utils import typing
NestedTensorDict = typing.NestedTensorDict
ModelType = Union[tf.keras.layers.Layer, tf.keras.Model]
@task_factory.register_task_cls(ocr_config.OcrTaskConfig)
@gin.configurable
class OcrTask(base_task.Task):
"""Defining the OCR training task."""
_loss_items = []
def __init__(self,
params: cfg.TaskConfig,
logging_dir: Optional[str] = None,
name: Optional[str] = None,
model_fn: Callable[..., ModelType] = gin.REQUIRED):
super().__init__(params, logging_dir, name)
self._model_fn = model_fn
def build_model(self) -> ModelType:
"""Build and return the model, record the loss items as well."""
model = self._model_fn()
self._loss_items.extend(model.loss_items)
return model
def build_inputs(
self,
params: cfg.DataConfig,
input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
"""Build the tf.data.Dataset instance."""
return input_reader.InputFn(is_training=params.is_training)({},
input_context)
def build_metrics(self,
training: bool = True) -> Sequence[tf.keras.metrics.Metric]:
"""Build the metrics (currently, only for loss summaries in TensorBoard)."""
del training
metrics = []
# Add loss items
for name in self._loss_items:
metrics.append(tf.keras.metrics.Mean(name, dtype=tf.float32))
# TODO(longshangbang): add evaluation metrics
return metrics
def train_step(
self,
inputs: Tuple[NestedTensorDict, NestedTensorDict],
model: ModelType,
optimizer: tf.keras.optimizers.Optimizer,
metrics: Optional[Sequence[tf.keras.metrics.Metric]] = None
) -> Dict[str, tf.Tensor]:
features, labels = inputs
input_dict = {"features": features}
if self.task_config.model_call_needs_labels:
input_dict["labels"] = labels
is_mixed_precision = isinstance(optimizer,
tf.keras.mixed_precision.LossScaleOptimizer)
with tf.GradientTape() as tape:
outputs = model(**input_dict, training=True)
loss, loss_dict = model.compute_losses(labels=labels, outputs=outputs)
loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
if is_mixed_precision:
loss = optimizer.get_scaled_loss(loss)
tvars = model.trainable_variables
grads = tape.gradient(loss, tvars)
if is_mixed_precision:
grads = optimizer.get_unscaled_gradients(grads)
optimizer.apply_gradients(list(zip(grads, tvars)))
logs = {"loss": loss}
if metrics:
for m in metrics:
m.update_state(loss_dict[m.name])
return logs
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorFlow Model Garden Vision training driver."""
from absl import app
from absl import flags
import gin
from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
# pylint: disable=unused-import
from official.projects.unified_detector import registry_imports
# pylint: enable=unused-import
FLAGS = flags.FLAGS
def main(_):
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
params = train_utils.parse_configuration(FLAGS)
model_dir = FLAGS.model_dir
if 'train' in FLAGS.mode:
# Pure eval modes do not output yaml files. Otherwise continuous eval job
# may race against the train job for writing the same file.
train_utils.serialize_config(params, model_dir)
# Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
# can have significant impact on model speeds by utilizing float16 in case of
# GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
# dtype is float16
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu)
with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir)
train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode=FLAGS.mode,
params=params,
model_dir=model_dir)
train_utils.save_gin_config(FLAGS.mode, model_dir)
if __name__ == '__main__':
tfm_flags.define_flags()
flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
app.run(main)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Typing extension."""
from typing import Dict, Union
import numpy as np
import tensorflow as tf
NpDict = Dict[str, np.ndarray]
FeaturesAndLabelsType = Dict[str, Dict[str, tf.Tensor]]
TensorDict = Dict[Union[str, int], tf.Tensor]
NestedTensorDict = Dict[
Union[str, int],
Union[tf.Tensor,
TensorDict]]