Commit b9e00ebf authored Jun 02, 2022 by Abdullah Rashwan; committed by A. Unique TensorFlower, Jun 02, 2022

Internal change

PiperOrigin-RevId: 452578619
Parent: 3e3b0c64

Changes (21): showing 20 changed files with 3893 additions and 0 deletions (+3893, -0)
official/projects/unified_detector/README.md (+163, -0)
official/projects/unified_detector/configs/gin_files/unified_detector_model.gin (+43, -0)
official/projects/unified_detector/configs/gin_files/unified_detector_train.gin (+22, -0)
official/projects/unified_detector/configs/ocr_config.py (+78, -0)
official/projects/unified_detector/data_conversion/convert.py (+66, -0)
official/projects/unified_detector/data_conversion/utils.py (+182, -0)
official/projects/unified_detector/data_loaders/autoaugment.py (+753, -0)
official/projects/unified_detector/data_loaders/input_reader.py (+270, -0)
official/projects/unified_detector/data_loaders/tf_example_decoder.py (+320, -0)
official/projects/unified_detector/data_loaders/universal_detection_parser.py (+606, -0)
official/projects/unified_detector/docs/images/task.png (+0, -0)
official/projects/unified_detector/external_configurables.py (+22, -0)
official/projects/unified_detector/modeling/universal_detector.py (+888, -0)
official/projects/unified_detector/registry_imports.py (+21, -0)
official/projects/unified_detector/requirements.txt (+8, -0)
official/projects/unified_detector/run_inference.py (+222, -0)
official/projects/unified_detector/tasks/all_models.py (+23, -0)
official/projects/unified_detector/tasks/ocr_task.py (+108, -0)
official/projects/unified_detector/train.py (+70, -0)
official/projects/unified_detector/utils/typing.py (+28, -0)
official/projects/unified_detector/README.md (new file, mode 100644)

# Towards End-to-End Unified Scene Text Detection and Layout Analysis

[Paper](https://arxiv.org/abs/2203.15143)

Official TensorFlow 2 implementation of the paper `Towards End-to-End Unified
Scene Text Detection and Layout Analysis`. If you encounter any issues using
the code, you are welcome to submit them to the Issues tab or send emails
directly to us: `hiertext@google.com`.

## Installation

### Set up TensorFlow Models

```bash
# (Optional) Create and enter a virtual environment
pip3 install --user virtualenv
virtualenv -p python3 unified_detector
source ./unified_detector/bin/activate

# First clone the TensorFlow Models project:
git clone https://github.com/tensorflow/models.git

# Install the requirements of TensorFlow Models and this repo:
cd models
pip3 install -r official/requirements.txt
pip3 install -r official/projects/unified_detector/requirements.txt

# Compile the protos
# If `protoc` is not installed, please follow:
# https://grpc.io/docs/protoc-installation/
export PYTHONPATH=${PYTHONPATH}:${PWD}/research/
cd research/object_detection/
protoc protos/string_int_label_map.proto --python_out=.
```
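Before moving on, you can verify that the proto compilation succeeded. This is a small check of ours, not part of the commit, and it assumes the stock `object_detection` package layout:

```python
# Run from `models/` with PYTHONPATH set as above. If the protoc step worked,
# this import resolves to the freshly generated *_pb2 module.
from object_detection.protos import string_int_label_map_pb2  # noqa: F401
print('string_int_label_map proto OK')
```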
### Set up Deeplab2

```bash
# Clone Deeplab2 anywhere you like
cd <somewhere>
git clone https://github.com/google-research/deeplab2.git

# Compile the protos
protoc deeplab2/*.proto --python_out=.

# Add to PYTHONPATH the directory where deeplab2 sits.
export PYTHONPATH=${PYTHONPATH}:${PWD}
```

## Running the model on some images using the provided checkpoint

### Download the checkpoint

Model | Input Resolution | #Object queries | Line PQ (val) | Paragraph PQ (val) | Line PQ (test) | Paragraph PQ (test)
----- | ---------------- | --------------- | ------------- | ------------------ | -------------- | -------------------
Unified-Detector-Line ([ckpt](https://storage.cloud.google.com/tf_model_garden/vision/unified_detector/unified_detector_ckpt.tgz)) | 1024 | 384 | 61.04 | 52.84 | 62.20 | 53.52

### Demo on single images

```bash
# Run from `models/`
python3 -m official.projects.unified_detector.run_inference \
  --gin_file=official/projects/unified_detector/configs/gin_files/unified_detector_model.gin \
  --ckpt_path=<path-of-the-ckpt> \
  --img_file=<some-image> \
  --output_path=<some-directory>/demo.jsonl \
  --vis_dir=<some-directory>
```

The output is stored as JSONL in the same hierarchical format as required by
the evaluation script of the HierText dataset. There will also be
visualizations of the word/line/paragraph boundaries. Note that the unified
detector produces line-level masks and an affinity matrix for grouping lines
into paragraphs. For visualization purposes, we split each line mask into
pixel groups defined as connected components, and visualize these groups as
`words`. They are not necessarily at the word granularity, though. We
visualize lines and paragraphs as groupings of these `words` using axis-aligned
bounding boxes.
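Since the JSONL output mirrors the HierText hierarchy, it can be inspected with a few lines of Python. This sketch is ours, not part of the commit; the key names (`paragraphs`, `lines`, `words`) and the one-record-per-image layout are assumptions taken from the HierText annotation schema, so verify them against the evaluation script:

```python
# Count detected lines and word groups per image in demo.jsonl.
import json

with open('demo.jsonl') as f:
    for row in f:
        record = json.loads(row)
        paragraphs = record.get('paragraphs', [])
        n_lines = sum(len(p.get('lines', [])) for p in paragraphs)
        n_words = sum(
            len(l.get('words', []))
            for p in paragraphs
            for l in p.get('lines', []))
        print(record.get('image_id'), n_lines, 'lines,', n_words, 'word groups')
```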
## Inference and Evaluation on the HierText dataset

### Download the HierText dataset

Clone the [HierText repo](https://github.com/google-research-datasets/hiertext)
and download the dataset. The `requirements.txt` in this folder already covers
those in the HierText repo, so there is no need to create a new virtual
environment.

### Inference and eval

The following commands run the model on the validation set and compute the
scores. Note that the test set annotation is not released yet, so only the
validation set is used here for demo purposes.

#### Inference

```bash
# Run from `models/`
python3 -m official.projects.unified_detector.run_inference \
  --gin_file=official/projects/unified_detector/configs/gin_files/unified_detector_model.gin \
  --ckpt_path=<path-of-the-ckpt> \
  --img_dir=<the-directory-containing-validation-images> \
  --output_path=<some-directory>/validation_output.jsonl
```

#### Evaluation

```bash
# Run from `hiertext/`
python3 eval.py \
  --gt=gt/validation.jsonl \
  --result=<some-directory>/validation_output.jsonl \
  --output=./validation-score.txt \
  --mask_stride=1 \
  --eval_lines \
  --eval_paragraphs \
  --num_workers=0
```

## Train new models

First, you will need to convert the HierText dataset into TFRecords:

```bash
# Run from `models/official/projects/unified_detector/data_conversion`
CUDA_VISIBLE_DEVICES='' python3 convert.py \
  --gt_file=/path/to/gt.jsonl \
  --img_dir=/path/to/image \
  --out_file=/path/to/tfrecords/file-prefix
```

To train the unified detector, run the following script:

```bash
# Run from `models/`
python3 -m official.projects.unified_detector.train \
  --mode=train \
  --experiment=unified_detector \
  --model_dir='<some path>' \
  --gin_file='official/projects/unified_detector/configs/gin_files/unified_detector_train.gin' \
  --gin_file='official/projects/unified_detector/configs/gin_files/unified_detector_model.gin' \
  --gin_params='InputFn.input_paths = ["/path/to/tfrecords/file-prefix*"]'
```

## Citation

Please cite our [paper](https://arxiv.org/pdf/2203.15143.pdf) if you find this
work helpful:

```
@inproceedings{long2022towards,
  title={Towards End-to-End Unified Scene Text Detection and Layout Analysis},
  author={Long, Shangbang and Qin, Siyang and Panteleev, Dmitry and Bissacco, Alessandro and Fujii, Yasuhisa and Raptis, Michalis},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2022}
}
```
official/projects/unified_detector/configs/gin_files/unified_detector_model.gin (new file, mode 100644)
# Defining the unified detector models.
# Model
## Backbone
num_slots = 384
SyncBatchNormalization.momentum = 0.95
get_max_deep_lab_backbone.num_slots = %num_slots
## Decoder
intermediate_filters = 256
num_entity_class = 3 # C + 1 (bkg) + 1 (void)
_get_decoder_head.atrous_rates = (6, 12, 18)
_get_decoder_head.pixel_space_dim = 128
_get_decoder_head.pixel_space_intermediate = %intermediate_filters
_get_decoder_head.num_classes = %num_entity_class
_get_decoder_head.aux_sem_intermediate = %intermediate_filters
_get_decoder_head.low_level = [
{'feature_key': 'res3', 'channels_project': 64,},
{'feature_key': 'res2', 'channels_project': 32,},]
_get_decoder_head.norm_fn = @SyncBatchNormalization
_get_embed_head.norm_fn = @LayerNorm
# Loss
# pq loss
alpha = 0.75
tau = 0.3
_entity_mask_loss.alpha = %alpha
_instance_discrimination_loss.tau = %tau
_paragraph_grouping_loss.tau = %tau
_paragraph_grouping_loss.loss_mode = 'balanced'
# Other Model setting
UniversalDetector.mask_threshold = 0.4
UniversalDetector.class_threshold = 0.5
UniversalDetector.filter_area = 32
universal_detection_loss_weights.loss_segmentation_word = 1e0
universal_detection_loss_weights.loss_inst_dist = 1e0
universal_detection_loss_weights.loss_mask_id = 1e-4
universal_detection_loss_weights.loss_pq = 3e0
universal_detection_loss_weights.loss_para = 1e0
official/projects/unified_detector/configs/gin_files/unified_detector_train.gin (new file, mode 100644)
# Defining the input pipeline of unified detector.
# ===== ===== Model ===== =====
# Internal import 2.
OcrTask.model_fn = @UniversalDetector
# ===== ===== Data pipeline ===== =====
InputFn.parser_fn = @UniDetectorParserFn
InputFn.dataset_type = 'tfrecord'
InputFn.batch_size = 256
# Internal import 3.
UniDetectorParserFn.output_dimension = 1024
# Simple data augmentation for now.
UniDetectorParserFn.rot90_probability = 0.0
UniDetectorParserFn.use_color_distortion = True
UniDetectorParserFn.crop_min_scale = 0.5
UniDetectorParserFn.crop_max_scale = 1.5
UniDetectorParserFn.crop_min_aspect = 0.8
UniDetectorParserFn.crop_max_aspect = 1.25
UniDetectorParserFn.max_num_instance = 384
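These two gin files are consumed by `train.py` and `run_inference.py` through the `--gin_file` flags shown in the README. As a rough sketch of the equivalent programmatic setup (ours, not part of the commit), the files can be parsed with `gin.parse_config_files_and_bindings`, assuming the binding targets (`UniversalDetector`, `InputFn`, ...) have already been registered with gin, which presumably is what importing `registry_imports` does:

```python
import gin
from official.projects.unified_detector import registry_imports  # noqa: F401

gin.parse_config_files_and_bindings(
    ['official/projects/unified_detector/configs/gin_files/unified_detector_model.gin',
     'official/projects/unified_detector/configs/gin_files/unified_detector_train.gin'],
    bindings=['InputFn.input_paths = ["/path/to/tfrecords/file-prefix*"]'])
```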
official/projects/unified_detector/configs/ocr_config.py (new file, mode 100644)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""OCR tasks and models configurations."""

import dataclasses

from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import optimization


@dataclasses.dataclass
class OcrTaskConfig(cfg.TaskConfig):
  train_data: cfg.DataConfig = cfg.DataConfig()
  model_call_needs_labels: bool = False


@exp_factory.register_config_factory('unified_detector')
def unified_detector() -> cfg.ExperimentConfig:
  """Configurations for trainer of unified detector."""
  total_train_steps = 100000
  summary_interval = steps_per_loop = 200
  checkpoint_interval = 2000
  warmup_steps = 1000
  config = cfg.ExperimentConfig(
      # Input pipeline and model are configured through Gin.
      task=OcrTaskConfig(train_data=cfg.DataConfig(is_training=True)),
      trainer=cfg.TrainerConfig(
          train_steps=total_train_steps,
          steps_per_loop=steps_per_loop,
          summary_interval=summary_interval,
          checkpoint_interval=checkpoint_interval,
          max_to_keep=1,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'adamw',
                  'adamw': {
                      'weight_decay_rate': 0.05,
                      'include_in_weight_decay': [
                          '^((?!depthwise).)*(kernel|weights):0$',
                      ],
                      'exclude_from_weight_decay': [
                          '(^((?!kernel).)*:0)|(depthwise_kernel)',
                      ],
                      'gradient_clip_norm': 10.,
                  },
              },
              'learning_rate': {
                  'type': 'cosine',
                  'cosine': {
                      'initial_learning_rate': 1e-3,
                      'decay_steps': total_train_steps - warmup_steps,
                      'alpha': 1e-2,
                      'offset': warmup_steps,
                  },
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_learning_rate': 1e-5,
                      'warmup_steps': warmup_steps,
                  }
              },
          }),
      ),
  )
  return config
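For reference, the registered experiment can be materialized by name through the Model Garden factory. A minimal sketch of ours, assuming this module has been imported so the registration decorator above runs:

```python
from official.core import exp_factory
# Importing the module triggers @exp_factory.register_config_factory above.
from official.projects.unified_detector.configs import ocr_config  # noqa: F401

config = exp_factory.get_exp_config('unified_detector')
print(config.trainer.train_steps)  # 100000
print(config.trainer.optimizer_config.optimizer.type)  # 'adamw'
```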
official/projects/unified_detector/data_conversion/convert.py (new file, mode 100644)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Script to convert HierText to TFExamples.

This script is only intended to run locally.

python3 data_preprocess/convert.py \
  --gt_file=/path/to/gt.jsonl \
  --img_dir=/path/to/image \
  --out_file=/path/to/tfrecords/file-prefix
"""

import json
import os
import random

from absl import app
from absl import flags
import tensorflow as tf
import tqdm

import utils

_GT_FILE = flags.DEFINE_string('gt_file', None, 'Path to the GT file')
_IMG_DIR = flags.DEFINE_string('img_dir', None, 'Path to the image folder.')
_OUT_FILE = flags.DEFINE_string('out_file', None, 'Path for the tfrecords.')
_NUM_SHARD = flags.DEFINE_integer('num_shard', 100,
                                  'The number of shards of tfrecords.')


def main(unused_argv) -> None:
  annotations = json.load(open(_GT_FILE.value))['annotations']
  random.shuffle(annotations)
  n_sample = len(annotations)
  n_shards = _NUM_SHARD.value
  n_sample_per_shard = (n_sample - 1) // n_shards + 1

  for shard in tqdm.tqdm(range(n_shards)):
    output_path = f'{_OUT_FILE.value}-{shard:05}-{n_shards:05}.tfrecords'
    annotation_subset = annotations[shard * n_sample_per_shard:(shard + 1) *
                                    n_sample_per_shard]
    with tf.io.TFRecordWriter(output_path) as file_writer:
      for annotation in annotation_subset:
        img_file_path = os.path.join(_IMG_DIR.value,
                                     f"{annotation['image_id']}.jpg")
        tfexample = utils.convert_to_tfe(img_file_path, annotation)
        file_writer.write(tfexample)


if __name__ == '__main__':
  flags.mark_flags_as_required(['gt_file', 'img_dir', 'out_file'])
  app.run(main)
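A quick way to sanity-check the generated shards is to count the records back. This snippet is our addition and uses only standard TF APIs:

```python
import tensorflow as tf

# The glob matches the `{prefix}-{shard:05}-{n_shards:05}.tfrecords` naming
# used by convert.py above.
files = tf.io.gfile.glob('/path/to/tfrecords/file-prefix*')
n_records = sum(1 for _ in tf.data.TFRecordDataset(files))
print(f'{len(files)} shards, {n_records} examples')
```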
official/projects/unified_detector/data_conversion/utils.py (new file, mode 100644)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities to convert data to TFExamples and store in TFRecords."""

from typing import Any, Dict, List, Tuple, Union

import cv2
import numpy as np
import tensorflow as tf


def encode_image(image_tensor: np.ndarray,
                 encoding_type: str = 'png') -> Union[np.ndarray, tf.Tensor]:
  """Encode image tensor into byte string."""
  if encoding_type == 'jpg':
    image_encoded = tf.image.encode_jpeg(tf.constant(image_tensor))
  elif encoding_type == 'png':
    image_encoded = tf.image.encode_png(tf.constant(image_tensor))
  else:
    raise ValueError('Invalid encoding type.')
  if tf.executing_eagerly():
    image_encoded = image_encoded.numpy()
  else:
    image_encoded = image_encoded.eval()
  return image_encoded


def int64_feature(value: Union[int, List[int]]) -> tf.train.Feature:
  if not isinstance(value, list):
    value = [value]
  return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def float_feature(value: Union[float, List[float]]) -> tf.train.Feature:
  if not isinstance(value, list):
    value = [value]
  return tf.train.Feature(float_list=tf.train.FloatList(value=value))


def bytes_feature(
    value: Union[Union[bytes, str], List[Union[bytes, str]]]
) -> tf.train.Feature:
  if not isinstance(value, list):
    value = [value]
  for i in range(len(value)):
    if not isinstance(value[i], bytes):
      value[i] = value[i].encode('utf-8')
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))


def annotation_to_entities(annotation: Dict[str, Any]) -> List[Dict[str, Any]]:
  """Flatten the annotation dict to a list of 'entities'."""
  entities = []
  for paragraph in annotation['paragraphs']:
    paragraph_id = len(entities)
    paragraph['type'] = 3  # 3 for paragraph
    paragraph['parent_id'] = -1
    entities.append(paragraph)
    for line in paragraph['lines']:
      line_id = len(entities)
      line['type'] = 2  # 2 for line
      line['parent_id'] = paragraph_id
      entities.append(line)
      for word in line['words']:
        word['type'] = 1  # 1 for word
        word['parent_id'] = line_id
        entities.append(word)
  return entities


def draw_entity_mask(entities: List[Dict[str, Any]],
                     image_shape: Tuple[int, int, int]) -> np.ndarray:
  """Draw entity id mask.

  Args:
    entities: A list of entity objects. Should be output from
      `annotation_to_entities`.
    image_shape: The shape of the input image.

  Returns:
    A (H, W, 3) entity id mask of the same height/width as the image. Each
    pixel (i, j, :) encodes the entity id of one pixel. Only word entities are
    rendered. 0 for non-text pixels; word entity ids start from 1.
  """
  instance_mask = np.zeros(image_shape, dtype=np.uint8)
  for i, entity in enumerate(entities):
    # only draw word masks
    if entity['type'] != 1:
      continue
    vertices = np.array(entity['vertices'])
    # the pixel value is actually 1 + position in entities
    entity_id = i + 1
    if entity_id >= 65536:
      # As entity_id is encoded in the last two channels, it should be less
      # than 256**2=65536.
      raise ValueError((f'Entity ID overflow: {entity_id}. Currently only '
                        'entity_id<65536 are supported.'))
    # use the last two channels to encode the entity id.
    color = [0, entity_id // 256, entity_id % 256]
    instance_mask = cv2.fillPoly(instance_mask,
                                 [np.round(vertices).astype('int32')], color)
  return instance_mask


def convert_to_tfe(img_file_name: str,
                   annotation: Dict[str, Any]) -> tf.train.Example:
  """Convert the annotation dict into a TFExample."""
  img = cv2.imread(img_file_name)
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  h, w, c = img.shape
  encoded_img = encode_image(img)

  entities = annotation_to_entities(annotation)
  masks = draw_entity_mask(entities, img.shape)
  encoded_mask = encode_image(masks)

  # encode attributes
  parent = []
  classes = []
  content_type = []
  text = []
  vertices = []
  for entity in entities:
    parent.append(entity['parent_id'])
    classes.append(entity['type'])
    # 0 for annotated; 8 for not annotated
    content_type.append((0 if entity['legible'] else 8))
    text.append(entity.get('text', ''))
    v = np.array(entity['vertices'])
    vertices.append(','.join(str(float(n)) for n in v.reshape(-1)))

  example = tf.train.Example(
      features=tf.train.Features(
          feature={
              # input images
              'image/encoded': bytes_feature(encoded_img),
              # image format
              'image/format': bytes_feature('png'),
              # image width
              'image/width': int64_feature([w]),
              # image height
              'image/height': int64_feature([h]),
              # image channels
              'image/channels': int64_feature([c]),
              # image key
              'image/source_id': bytes_feature(annotation['image_id']),
              # HxWx3 tensors: channel 2-3 encodes the id of the word entity.
              'image/additional_channels/encoded': bytes_feature(encoded_mask),
              # format of the additional channels
              'image/additional_channels/format': bytes_feature('png'),
              'image/object/parent': int64_feature(parent),
              # word / line / paragraph / symbol / ...
              'image/object/classes': int64_feature(classes),
              # text / handwritten / not-annotated / ...
              'image/object/content_type': int64_feature(content_type),
              # string text transcription
              'image/object/text': bytes_feature(text),
              # comma separated coordinates, (x,y) * n
              'image/object/vertices': bytes_feature(vertices),
          })).SerializeToString()
  return example
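The entity-id mask round-trips as follows: `draw_entity_mask` stores each word's id in the last two channels, so id = G * 256 + B, with 0 for background. A short sketch of ours that parses one record back using the feature keys written above (the shard path is a placeholder):

```python
import numpy as np
import tensorflow as tf

feature_spec = {
    'image/encoded': tf.io.FixedLenFeature([], tf.string),
    'image/additional_channels/encoded': tf.io.FixedLenFeature([], tf.string),
}
# Use any shard produced by convert.py.
raw = next(iter(tf.data.TFRecordDataset(
    ['/path/to/tfrecords/file-prefix-00000-00100.tfrecords'])))
parsed = tf.io.parse_single_example(raw, feature_spec)
mask = tf.io.decode_png(parsed['image/additional_channels/encoded']).numpy()
entity_ids = mask[..., 1].astype(np.int32) * 256 + mask[..., 2]
print('word entity ids present:', np.unique(entity_ids))
```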
official/projects/unified_detector/data_loaders/autoaugment.py (new file, mode 100644)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""AutoAugment and RandAugment policies for enhanced image preprocessing.

AutoAugment Reference: https://arxiv.org/abs/1805.09501
RandAugment Reference: https://arxiv.org/abs/1909.13719

This library is adapted from:
`https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py`.
Several changes are made. They are inspired by the TIMM library:
https://github.com/rwightman/pytorch-image-models/tree/master/timm/data

Changes include:
(1) Random Erasing / Cutout is added, and separated from the random
    augmentation pool (not sampled as an operation).
(2) For `posterize` and `solarize`, the arguments are changed such that the
    level of corruption increases as the `magnitude` argument increases.
(3) `color`, `contrast`, `brightness`, `sharpness` are randomly enhanced or
    diminished.
(4) Magnitude is randomly sampled from a normal distribution.
(5) Operations are applied with a probability.
"""

import inspect
import math

import tensorflow as tf
import tensorflow_addons.image as tfa_image

# This signifies the max integer that the controller RNN could predict for the
# augmentation scheme.
_MAX_LEVEL = 10.


def policy_v0():
  """Autoaugment policy that was used in AutoAugment Paper."""
  # Each tuple is an augmentation operation of the form
  # (operation, probability, magnitude). Each element in policy is a
  # sub-policy that will be applied sequentially on the image.
  policy = [
      [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
      [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
      [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
      [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
      [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
      [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
      [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
      [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
      [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
      [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
      [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
      [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
      [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
      [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
      [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
      [('Rotate', 1.0, 7), ('TranslateY', 0.8, 9)],
      [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
      [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
      [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
      [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
      [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
      [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
      [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],
      [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
      [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
  ]
  return policy


def policy_vtest():
  """Autoaugment test policy for debugging."""
  # Each tuple is an augmentation operation of the form
  # (operation, probability, magnitude). Each element in policy is a
  # sub-policy that will be applied sequentially on the image.
  policy = [
      [('TranslateX', 1.0, 4), ('Equalize', 1.0, 10)],
  ]
  return policy


# pylint: disable=g-long-lambda
blend = tf.function(
    lambda i1, i2, factor: tf.cast(
        tfa_image.blend(
            tf.cast(i1, tf.float32), tf.cast(i2, tf.float32), factor),
        tf.uint8))
# pylint: enable=g-long-lambda


def random_erase(image,
                 prob,
                 min_area=0.02,
                 max_area=1 / 3,
                 min_aspect=1 / 3,
                 max_aspect=10 / 3,
                 mode='pixel'):
  """The random erasing augmentations: https://arxiv.org/pdf/1708.04896.pdf.

  This augmentation is applied after image normalization.

  Args:
    image: Input image after all other augmentation and normalization. It has
      type tf.float32.
    prob: Probability of applying the random erasing operation.
    min_area: As named.
    max_area: As named.
    min_aspect: As named.
    max_aspect: As named.
    mode: How the erased area is filled. 'pixel' means white noise (uniform
      dist).

  Returns:
    Randomly erased image.
  """
  image_height = tf.shape(image)[0]
  image_width = tf.shape(image)[1]
  image_area = tf.cast(image_width * image_height, tf.float32)

  # Sample width, height
  erase_area = tf.random.uniform([], min_area, max_area) * image_area
  log_max_target_ar = tf.math.log(
      tf.minimum(
          tf.math.divide(
              tf.math.square(tf.cast(image_width, tf.float32)), erase_area),
          max_aspect))
  log_min_target_ar = tf.math.log(
      tf.maximum(
          tf.math.divide(
              erase_area, tf.math.square(tf.cast(image_height, tf.float32))),
          min_aspect))
  erase_aspect_ratio = tf.math.exp(
      tf.random.uniform([], log_min_target_ar, log_max_target_ar))
  erase_h = tf.cast(tf.math.sqrt(erase_area / erase_aspect_ratio), tf.int32)
  erase_w = tf.cast(tf.math.sqrt(erase_area * erase_aspect_ratio), tf.int32)

  # Sample (left, top) of the rectangle to erase
  erase_left = tf.random.uniform(
      shape=[], minval=0, maxval=image_width - erase_w, dtype=tf.int32)
  erase_top = tf.random.uniform(
      shape=[], minval=0, maxval=image_height - erase_h, dtype=tf.int32)
  pad_right = image_width - erase_w - erase_left
  pad_bottom = image_height - erase_h - erase_top

  mask = tf.pad(
      tf.zeros([erase_h, erase_w], dtype=image.dtype),
      [[erase_top, pad_bottom], [erase_left, pad_right]],
      constant_values=1)
  mask = tf.expand_dims(mask, -1)  # [H, W, 1]

  if mode == 'pixel':
    fill = tf.random.truncated_normal(
        tf.shape(image), 0.0, 1.0, dtype=image.dtype)
  else:
    fill = tf.zeros(tf.shape(image), dtype=image.dtype)

  should_apply_op = tf.cast(
      tf.floor(tf.random.uniform([], dtype=tf.float32) + prob), tf.bool)
  augmented_image = tf.cond(should_apply_op,
                            lambda: mask * image + (1 - mask) * fill,
                            lambda: image)
  return augmented_image


def solarize(image, threshold=128):
  # For each pixel in the image, select the pixel
  # if the value is less than the threshold.
  # Otherwise, subtract 255 from the pixel.
  return tf.where(image < threshold, image, 255 - image)


def solarize_add(image, addition=0, threshold=128):
  # For each pixel in the image less than threshold
  # we add 'addition' amount to it and then clip the
  # pixel value to be between 0 and 255. The value
  # of 'addition' is between -128 and 128.
  added_image = tf.cast(image, tf.int64) + addition
  added_image = tf.cast(tf.clip_by_value(added_image, 0, 255), tf.uint8)
  return tf.where(image < threshold, added_image, image)


def color(image, factor):
  """Equivalent of PIL Color."""
  degenerate = tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(image))
  return blend(degenerate, image, factor)


def contrast(image, factor):
  """Equivalent of PIL Contrast."""
  degenerate = tf.image.rgb_to_grayscale(image)
  # Cast before calling tf.histogram.
  degenerate = tf.cast(degenerate, tf.int32)

  # Compute the grayscale histogram, then compute the mean pixel value,
  # and create a constant image size of that value. Use that as the
  # blending degenerate target of the original image.
  hist = tf.histogram_fixed_width(degenerate, [0, 255], nbins=256)
  mean = tf.reduce_sum(tf.cast(hist, tf.float32)) / 256.0
  degenerate = tf.ones_like(degenerate, dtype=tf.float32) * mean
  degenerate = tf.clip_by_value(degenerate, 0.0, 255.0)
  degenerate = tf.image.grayscale_to_rgb(tf.cast(degenerate, tf.uint8))
  return blend(degenerate, image, factor)


def brightness(image, factor):
  """Equivalent of PIL Brightness."""
  degenerate = tf.zeros_like(image)
  return blend(degenerate, image, factor)


def posterize(image, bits):
  """Equivalent of PIL Posterize. Smaller `bits` means larger degradation."""
  shift = 8 - bits
  return tf.bitwise.left_shift(tf.bitwise.right_shift(image, shift), shift)


def rotate(image, degrees, replace):
  """Rotates the image by degrees either clockwise or counterclockwise.

  Args:
    image: An image Tensor of type uint8.
    degrees: Float, a scalar angle in degrees to rotate all images by. If
      degrees is positive the image will be rotated clockwise otherwise it
      will be rotated counterclockwise.
    replace: A one or three value 1D tensor to fill empty pixels caused by the
      rotate operation.

  Returns:
    The rotated version of image.
  """
  # Convert from degrees to radians.
  degrees_to_radians = math.pi / 180.0
  radians = degrees * degrees_to_radians
  # In practice, we should randomize the rotation degrees by flipping
  # it negatively half the time, but that's done on 'degrees' outside
  # of the function.
  if isinstance(replace, list) or isinstance(replace, tuple):
    replace = replace[0]
  image = tfa_image.rotate(image, radians, fill_value=replace)
  return image


def translate_x(image, pixels, replace):
  """Equivalent of PIL Translate in X dimension."""
  return tfa_image.translate_xy(image, [-pixels, 0], replace)


def translate_y(image, pixels, replace):
  """Equivalent of PIL Translate in Y dimension."""
  return tfa_image.translate_xy(image, [0, -pixels], replace)


def autocontrast(image):
  """Implements Autocontrast function from PIL using TF ops.

  Args:
    image: A 3D uint8 tensor.

  Returns:
    The image after it has had autocontrast applied to it and will be of type
    uint8.
  """

  def scale_channel(image):
    """Scale the 2D image using the autocontrast rule."""
    # A possibly cheaper version can be done using cumsum/unique_with_counts
    # over the histogram values, rather than iterating over the entire image
    # to compute mins and maxes.
    lo = tf.cast(tf.reduce_min(image), tf.float32)
    hi = tf.cast(tf.reduce_max(image), tf.float32)

    # Scale the image, making the lowest value 0 and the highest value 255.
    def scale_values(im):
      scale = 255.0 / (hi - lo)
      offset = -lo * scale
      im = tf.cast(im, tf.float32) * scale + offset
      im = tf.clip_by_value(im, 0.0, 255.0)
      return tf.cast(im, tf.uint8)

    result = tf.cond(hi > lo, lambda: scale_values(image), lambda: image)
    return result

  # Assumes RGB for now. Scales each channel independently
  # and then stacks the result.
  s1 = scale_channel(image[:, :, 0])
  s2 = scale_channel(image[:, :, 1])
  s3 = scale_channel(image[:, :, 2])
  image = tf.stack([s1, s2, s3], 2)
  return image


def sharpness(image, factor):
  """Implements Sharpness function from PIL using TF ops."""
  orig_image = image
  image = tf.cast(image, tf.float32)
  # Make image 4D for conv operation.
  image = tf.expand_dims(image, 0)
  # SMOOTH PIL Kernel.
  kernel = tf.constant(
      [[1, 1, 1], [1, 5, 1], [1, 1, 1]],
      dtype=tf.float32,
      shape=[3, 3, 1, 1]) / 13.
  # Tile across channel dimension.
  kernel = tf.tile(kernel, [1, 1, 3, 1])
  strides = [1, 1, 1, 1]
  with tf.device('/cpu:0'):
    # Some augmentation that uses depth-wise conv will cause crashing when
    # training on GPU. See (b/156242594) for details.
    degenerate = tf.nn.depthwise_conv2d(
        image, kernel, strides, padding='VALID')
  degenerate = tf.clip_by_value(degenerate, 0.0, 255.0)
  degenerate = tf.squeeze(tf.cast(degenerate, tf.uint8), [0])

  # For the borders of the resulting image, fill in the values of the
  # original image.
  mask = tf.ones_like(degenerate)
  padded_mask = tf.pad(mask, [[1, 1], [1, 1], [0, 0]])
  padded_degenerate = tf.pad(degenerate, [[1, 1], [1, 1], [0, 0]])
  result = tf.where(tf.equal(padded_mask, 1), padded_degenerate, orig_image)

  # Blend the final result.
  return blend(result, orig_image, factor)


def equalize(image):
  """Implements Equalize function from PIL using TF ops."""

  def scale_channel(im, c):
    """Scale the data in the channel to implement equalize."""
    im = tf.cast(im[:, :, c], tf.int32)
    # Compute the histogram of the image channel.
    histo = tf.histogram_fixed_width(im, [0, 255], nbins=256)

    # For the purposes of computing the step, filter out the nonzeros.
    nonzero = tf.where(tf.not_equal(histo, 0))
    nonzero_histo = tf.reshape(tf.gather(histo, nonzero), [-1])
    step = (tf.reduce_sum(nonzero_histo) - nonzero_histo[-1]) // 255

    def build_lut(histo, step):
      # Compute the cumulative sum, shifting by step // 2
      # and then normalization by step.
      lut = (tf.cumsum(histo) + (step // 2)) // step
      # Shift lut, prepending with 0.
      lut = tf.concat([[0], lut[:-1]], 0)
      # Clip the counts to be in range. This is done
      # in the C code for image.point.
      return tf.clip_by_value(lut, 0, 255)

    # If step is zero, return the original image. Otherwise, build
    # lut from the full histogram and step and then index from it.
    result = tf.cond(
        tf.equal(step, 0), lambda: im,
        lambda: tf.gather(build_lut(histo, step), im))

    return tf.cast(result, tf.uint8)

  # Assumes RGB for now. Scales each channel independently
  # and then stacks the result.
  s1 = scale_channel(image, 0)
  s2 = scale_channel(image, 1)
  s3 = scale_channel(image, 2)
  image = tf.stack([s1, s2, s3], 2)
  return image


def invert(image):
  """Inverts the image pixels."""
  image = tf.convert_to_tensor(image)
  return 255 - image


NAME_TO_FUNC = {
    'AutoContrast': autocontrast,
    'Equalize': equalize,
    'Invert': invert,
    'Rotate': rotate,
    'Posterize': posterize,
    'PosterizeIncreasing': posterize,
    'Solarize': solarize,
    'SolarizeIncreasing': solarize,
    'SolarizeAdd': solarize_add,
    'Color': color,
    'ColorIncreasing': color,
    'Contrast': contrast,
    'ContrastIncreasing': contrast,
    'Brightness': brightness,
    'BrightnessIncreasing': brightness,
    'Sharpness': sharpness,
    'SharpnessIncreasing': sharpness,
    'ShearX': tfa_image.shear_x,
    'ShearY': tfa_image.shear_y,
    'TranslateX': translate_x,
    'TranslateY': translate_y,
    'Cutout': tfa_image.random_cutout,
    'Hue': tf.image.adjust_hue,
}


def _randomly_negate_tensor(tensor):
  """With 50% prob turn the tensor negative."""
  should_flip = tf.cast(tf.floor(tf.random.uniform([]) + 0.5), tf.bool)
  final_tensor = tf.cond(should_flip, lambda: -tensor, lambda: tensor)
  return final_tensor


def _rotate_level_to_arg(level):
  level = (level / _MAX_LEVEL) * 30.
  level = _randomly_negate_tensor(level)
  return (level,)


def _shrink_level_to_arg(level):
  """Converts level to ratio by which we shrink the image content."""
  if level == 0:
    return (1.0,)  # if level is zero, do not shrink the image
  # Maximum shrinking ratio is 2.9.
  level = 2. / (_MAX_LEVEL / level) + 0.9
  return (level,)


def _enhance_level_to_arg(level):
  return ((level / _MAX_LEVEL) * 1.8 + 0.1,)


def _enhance_increasing_level_to_arg(level):
  level = (level / _MAX_LEVEL) * .9
  level = 1.0 + _randomly_negate_tensor(level)
  return (level,)


def _shear_level_to_arg(level):
  level = (level / _MAX_LEVEL) * 0.3
  # Flip level to negative with 50% chance.
  level = _randomly_negate_tensor(level)
  return (level,)


def _translate_level_to_arg(level, translate_const):
  level = level / _MAX_LEVEL * translate_const
  # Flip level to negative with 50% chance.
  level = _randomly_negate_tensor(level)
  return (level,)


def _posterize_level_to_arg(level):
  return (tf.cast(level / _MAX_LEVEL * 4, tf.uint8),)


def _posterize_increase_level_to_arg(level):
  return (4 - _posterize_level_to_arg(level)[0],)


def _solarize_level_to_arg(level):
  return (tf.cast(level / _MAX_LEVEL * 256, tf.uint8),)


def _solarize_increase_level_to_arg(level):
  return (256 - _solarize_level_to_arg(level)[0],)


def _solarize_add_level_to_arg(level):
  return (tf.cast(level / _MAX_LEVEL * 110, tf.int64),)


def _cutout_arg(level, cutout_size):
  pad_size = tf.cast(level / _MAX_LEVEL * cutout_size, tf.int32)
  return (2 * pad_size, 2 * pad_size)


def level_to_arg(hparams):
  return {
      'AutoContrast': lambda level: (),
      'Equalize': lambda level: (),
      'Invert': lambda level: (),
      'Rotate': _rotate_level_to_arg,
      'Posterize': _posterize_level_to_arg,
      'PosterizeIncreasing': _posterize_increase_level_to_arg,
      'Solarize': _solarize_level_to_arg,
      'SolarizeIncreasing': _solarize_increase_level_to_arg,
      'SolarizeAdd': _solarize_add_level_to_arg,
      'Color': _enhance_level_to_arg,
      'ColorIncreasing': _enhance_increasing_level_to_arg,
      'Contrast': _enhance_level_to_arg,
      'ContrastIncreasing': _enhance_increasing_level_to_arg,
      'Brightness': _enhance_level_to_arg,
      'BrightnessIncreasing': _enhance_increasing_level_to_arg,
      'Sharpness': _enhance_level_to_arg,
      'SharpnessIncreasing': _enhance_increasing_level_to_arg,
      'ShearX': _shear_level_to_arg,
      'ShearY': _shear_level_to_arg,
      # pylint:disable=g-long-lambda
      'Cutout': lambda level: _cutout_arg(level, hparams['cutout_const']),
      'TranslateX':
          lambda level: _translate_level_to_arg(level,
                                                hparams['translate_const']),
      'TranslateY':
          lambda level: _translate_level_to_arg(level,
                                                hparams['translate_const']),
      'Hue': lambda level: ((level / _MAX_LEVEL) * 0.25,),
      # pylint:enable=g-long-lambda
  }


def _parse_policy_info(name, prob, level, replace_value, augmentation_hparams):
  """Return the function that corresponds to `name` and update `level` param."""
  func = NAME_TO_FUNC[name]
  args = level_to_arg(augmentation_hparams)[name](level)

  # Add in replace arg if it is required for the function that is being called.
  # pytype:disable=wrong-arg-types
  if 'replace' in inspect.signature(func).parameters.keys():  # pylint: disable=deprecated-method
    args = tuple(list(args) + [replace_value])
  # pytype:enable=wrong-arg-types

  return (func, prob, args)


def _apply_func_with_prob(func, image, args, prob):
  """Apply `func` to image w/ `args` as input with probability `prob`."""
  assert isinstance(args, tuple)

  # Apply the function with probability `prob`.
  should_apply_op = tf.cast(
      tf.floor(tf.random.uniform([], dtype=tf.float32) + prob), tf.bool)
  augmented_image = tf.cond(should_apply_op, lambda: func(image, *args),
                            lambda: image)
  return augmented_image


def select_and_apply_random_policy(policies, image):
  """Select a random policy from `policies` and apply it to `image`."""
  policy_to_select = tf.random.uniform(
      [], maxval=len(policies), dtype=tf.int32)
  # Note that using tf.case instead of tf.conds would result in significantly
  # larger graphs and would even break export for some larger policies.
  for (i, policy) in enumerate(policies):
    image = tf.cond(
        tf.equal(i, policy_to_select),
        lambda selected_policy=policy: selected_policy(image),
        lambda: image)
  return image


def build_and_apply_nas_policy(policies, image, augmentation_hparams):
  """Build a policy from the given policies passed in and apply to image.

  Args:
    policies: list of lists of tuples in the form `(func, prob, level)`,
      `func` is a string name of the augmentation function, `prob` is the
      probability of applying the `func` operation, `level` is the input
      argument for `func`.
    image: tf.Tensor that the resulting policy will be applied to.
    augmentation_hparams: Hparams associated with the NAS learned policy.

  Returns:
    A version of image that now has data augmentation applied to it based on
    the `policies` passed into the function.
  """
  replace_value = [128, 128, 128]

  # func is the string name of the augmentation function, prob is the
  # probability of applying the operation and level is the parameter
  # associated with the tf op.

  # tf_policies are functions that take in an image and return an augmented
  # image.
  tf_policies = []
  for policy in policies:
    tf_policy = []
    # Link string name to the correct python function and make sure the
    # correct argument is passed into that function.
    for policy_info in policy:
      policy_info = list(policy_info) + [replace_value, augmentation_hparams]
      tf_policy.append(_parse_policy_info(*policy_info))

    # Now build the tf policy that will apply the augmentation procedure
    # on image.
    def make_final_policy(tf_policy_):

      def final_policy(image_):
        for func, prob, args in tf_policy_:
          image_ = _apply_func_with_prob(func, image_, args, prob)
        return image_

      return final_policy

    tf_policies.append(make_final_policy(tf_policy))

  augmented_image = select_and_apply_random_policy(tf_policies, image)
  return augmented_image


def distort_image_with_autoaugment(image, augmentation_name):
  """Applies the AutoAugment policy to `image`.

  AutoAugment is from the paper: https://arxiv.org/abs/1805.09501.

  Args:
    image: `Tensor` of shape [height, width, 3] representing an image.
    augmentation_name: The name of the AutoAugment policy to use. The
      available options are `v0` and `test`. `v0` is the policy used for all
      of the results in the paper and was found to achieve the best results
      on the COCO dataset. `v1`, `v2` and `v3` are additional good policies
      found on the COCO dataset that have slight variation in what operations
      were used during the search procedure along with how many operations are
      applied in parallel to a single image (2 vs 3).

  Returns:
    A tuple containing the augmented versions of `image`.
  """
  available_policies = {'v0': policy_v0, 'test': policy_vtest}
  if augmentation_name not in available_policies:
    raise ValueError(
        'Invalid augmentation_name: {}'.format(augmentation_name))

  policy = available_policies[augmentation_name]()
  # Hparams that will be used for AutoAugment.
  augmentation_hparams = dict(cutout_const=100, translate_const=250)

  return build_and_apply_nas_policy(policy, image, augmentation_hparams)


# Cutout is implemented separately.
_RAND_TRANSFORMS = [
    'AutoContrast',
    'Equalize',
    'Invert',
    'Rotate',
    'Posterize',
    'Solarize',
    'Color',
    'Contrast',
    'Brightness',
    'Sharpness',
    'ShearX',
    'ShearY',
    'TranslateX',
    'TranslateY',
    'SolarizeAdd',
    'Hue',
]

# Cutout is implemented separately.
_RAND_INCREASING_TRANSFORMS = [
    'AutoContrast',
    'Equalize',
    'Invert',
    'Rotate',
    'PosterizeIncreasing',
    'SolarizeIncreasing',
    'SolarizeAdd',
    'ColorIncreasing',
    'ContrastIncreasing',
    'BrightnessIncreasing',
    'SharpnessIncreasing',
    'ShearX',
    'ShearY',
    'TranslateX',
    'TranslateY',
    'Hue',
]

# These augmentations are not suitable for detection task.
_NON_COLOR_DISTORTION_OPS = [
    'Rotate',
    'ShearX',
    'ShearY',
    'TranslateX',
    'TranslateY',
]


def distort_image_with_randaugment(image, num_layers, magnitude, mag_std, inc,
                                   prob, color_only=False):
  """Applies the RandAugment policy to `image`.

  RandAugment is from the paper https://arxiv.org/abs/1909.13719.

  Args:
    image: `Tensor` of shape [height, width, 3] representing an image. The
      image should have uint8 type in [0, 255].
    num_layers: Integer, the number of augmentation transformations to apply
      sequentially to an image. Represented as (N) in the paper. Usually best
      values will be in the range [1, 3].
    magnitude: Integer, shared magnitude across all augmentation operations.
      Represented as (M) in the paper. Usually best values are in the range
      [5, 30].
    mag_std: Randomness of magnitude. The magnitude will be sampled from a
      normal distribution on the fly.
    inc: Whether to select aug that increases as magnitude increases.
    prob: Probability of any aug being applied.
    color_only: Whether only apply operations that distort color and do not
      change spatial layouts.

  Returns:
    The augmented version of `image`.
  """
  replace_value = [128] * 3
  augmentation_hparams = dict(cutout_const=40, translate_const=100)
  available_ops = _RAND_INCREASING_TRANSFORMS if inc else _RAND_TRANSFORMS
  if color_only:
    available_ops = list(
        filter(lambda op: op not in _NON_COLOR_DISTORTION_OPS, available_ops))

  for layer_num in range(num_layers):
    op_to_select = tf.random.uniform(
        [], maxval=len(available_ops), dtype=tf.int32)
    random_magnitude = tf.clip_by_value(
        tf.random.normal([], magnitude, mag_std), 0., _MAX_LEVEL)
    with tf.name_scope('randaug_layer_{}'.format(layer_num)):
      for (i, op_name) in enumerate(available_ops):
        func, _, args = _parse_policy_info(op_name, prob, random_magnitude,
                                           replace_value,
                                           augmentation_hparams)
        image = tf.cond(
            tf.equal(i, op_to_select),
            # pylint:disable=g-long-lambda
            lambda s_func=func, s_args=args: _apply_func_with_prob(
                s_func, image, s_args, prob),
            # pylint:enable=g-long-lambda
            lambda: image)

  return image
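For context, a typical call from a data-loading step might look like the following. The parameter values are illustrative only, not the settings used by the unified detector:

```python
import tensorflow as tf

image = tf.cast(
    tf.random.uniform([640, 640, 3], maxval=256, dtype=tf.int32), tf.uint8)
# color_only=True restricts RandAugment to photometric ops, keeping geometry
# (and hence detection ground truth) unchanged.
augmented = distort_image_with_randaugment(
    image, num_layers=2, magnitude=7, mag_std=0.5, inc=True, prob=0.7,
    color_only=True)
```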
official/projects/unified_detector/data_loaders/input_reader.py (new file, mode 100644)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Input data reader.

Creates a tf.data.Dataset object from multiple input sstables and uses a
provided data parser function to decode the serialized tf.Example and
optionally run data augmentation.
"""

import os
from typing import Any, Callable, List, Optional, Sequence, Union

import gin
from six.moves import map
import tensorflow as tf

from official.common import dataset_fn
from research.object_detection.utils import label_map_util
from official.core import config_definitions as cfg
from official.projects.unified_detector.data_loaders import universal_detection_parser  # pylint: disable=unused-import

FuncType = Callable[..., Any]


@gin.configurable(denylist=['is_training'])
class InputFn(object):
  """Input data reader class.

  Creates a tf.data.Dataset object from multiple datasets (optionally performs
  weighted sampling between different datasets), parses the tf.Example message
  using `parser_fn`. The datasets can either be stored in SSTable or TfRecord.
  """

  def __init__(self,
               is_training: bool,
               batch_size: Optional[int] = None,
               data_root: str = '',
               input_paths: List[str] = gin.REQUIRED,
               dataset_type: str = 'tfrecord',
               use_sampling: bool = False,
               sampling_weights: Optional[Sequence[Union[int, float]]] = None,
               cycle_length: Optional[int] = 64,
               shuffle_buffer_size: Optional[int] = 512,
               parser_fn: Optional[FuncType] = None,
               parser_num_parallel_calls: Optional[int] = 64,
               max_intra_op_parallelism: Optional[int] = None,
               label_map_proto_path: Optional[str] = None,
               input_filter_fns: Optional[List[FuncType]] = None,
               input_training_filter_fns: Optional[Sequence[FuncType]] = None,
               dense_to_ragged_batch: bool = False,
               data_validator_fn: Optional[Callable[[Sequence[str]],
                                                    None]] = None):
    """Input reader constructor.

    Args:
      is_training: Boolean indicating TRAIN or EVAL.
      batch_size: Input data batch size. Ignored if batch size is passed
        through params. In that case, this can be None.
      data_root: All the relative input paths are based on this location.
      input_paths: Input file patterns.
      dataset_type: Can be 'sstable' or 'tfrecord'.
      use_sampling: Whether to perform weighted sampling between different
        datasets.
      sampling_weights: Unnormalized sampling weights. The length should be
        equal to `input_paths`.
      cycle_length: The number of input Datasets to interleave from in
        parallel. If set to None tf.data experimental autotuning is used.
      shuffle_buffer_size: The random shuffle buffer size.
      parser_fn: The function to run decoding and data augmentation. The
        function takes `is_training` as an input, which is passed from here.
      parser_num_parallel_calls: The number of parallel calls for `parser_fn`.
        The number of CPU cores is the suggested value. If set to None tf.data
        experimental autotuning is used.
      max_intra_op_parallelism: if set limits the max intra op parallelism of
        functions run on slices of the input.
      label_map_proto_path: Path to a StringIntLabelMap which will be used to
        decode the input data.
      input_filter_fns: A list of functions on the dataset points which
        return true for valid data.
      input_training_filter_fns: A list of functions on the dataset points
        which return true for valid data used only for training.
      dense_to_ragged_batch: Whether to use ragged batching for MPNN format.
      data_validator_fn: If not None, used to validate the data specified by
        input_paths.

    Raises:
      ValueError for invalid input_paths.
    """
    self._is_training = is_training
    if data_root:
      # If an input path is absolute this does not change it.
      input_paths = [os.path.join(data_root, value) for value in input_paths]
    self._input_paths = input_paths
    self._batch_size = batch_size
    # Disables datasets sampling during eval.
    if is_training:
      self._use_sampling = use_sampling
    else:
      self._use_sampling = False
    self._sampling_weights = sampling_weights
    self._cycle_length = (cycle_length if cycle_length else tf.data.AUTOTUNE)
    self._shuffle_buffer_size = shuffle_buffer_size
    self._parser_num_parallel_calls = (
        parser_num_parallel_calls
        if parser_num_parallel_calls else tf.data.AUTOTUNE)
    self._max_intra_op_parallelism = max_intra_op_parallelism
    self._label_map_proto_path = label_map_proto_path
    if label_map_proto_path:
      name_to_id = label_map_util.get_label_map_dict(label_map_proto_path)
      self._lookup_str_keys = list(name_to_id.keys())
      self._lookup_int_values = list(name_to_id.values())
    self._parser_fn = parser_fn
    self._input_filter_fns = input_filter_fns or []
    if is_training and input_training_filter_fns:
      self._input_filter_fns.extend(input_training_filter_fns)
    self._dataset_type = dataset_type
    self._dense_to_ragged_batch = dense_to_ragged_batch
    if data_validator_fn is not None:
      data_validator_fn(self._input_paths)

  @property
  def batch_size(self):
    return self._batch_size

  def __call__(
      self,
      params: cfg.DataConfig,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    """Read and parse input datasets, return a tf.data.Dataset object."""
    # TPUEstimator passes the batch size through params.
    if params is not None and 'batch_size' in params:
      batch_size = params['batch_size']
    else:
      batch_size = self._batch_size
    per_replica_batch_size = input_context.get_per_replica_batch_size(
        batch_size) if input_context else batch_size

    with tf.name_scope('input_reader'):
      dataset = self._build_dataset_from_records()
      dataset_parser_fn = self._build_dataset_parser_fn()
      dataset = dataset.map(
          dataset_parser_fn,
          num_parallel_calls=self._parser_num_parallel_calls)
      for filter_fn in self._input_filter_fns:
        dataset = dataset.filter(filter_fn)
      if self._dense_to_ragged_batch:
        dataset = dataset.apply(
            tf.data.experimental.dense_to_ragged_batch(
                batch_size=per_replica_batch_size, drop_remainder=True))
      else:
        dataset = dataset.batch(per_replica_batch_size, drop_remainder=True)
      dataset = dataset.prefetch(tf.data.AUTOTUNE)
      return dataset

  def _fetch_dataset(self, filename: str) -> tf.data.Dataset:
    """Fetch dataset depending on type.

    Args:
      filename: Location of dataset.

    Returns:
      Tf Dataset.
    """
    data_cls = dataset_fn.pick_dataset_fn(self._dataset_type)
    data = data_cls([filename])
    return data

  def _build_dataset_parser_fn(self) -> Callable[..., tf.Tensor]:
    """Depending on label_map and storage type, build a parser_fn."""
    # Parse the fetched records to input tensors for model function.
    if self._label_map_proto_path:
      lookup_initializer = tf.lookup.KeyValueTensorInitializer(
          keys=tf.constant(self._lookup_str_keys, dtype=tf.string),
          values=tf.constant(self._lookup_int_values, dtype=tf.int32))
      name_to_id_table = tf.lookup.StaticHashTable(
          initializer=lookup_initializer, default_value=0)
      parser_fn = self._parser_fn(
          is_training=self._is_training,
          label_lookup_table=name_to_id_table)
    else:
      parser_fn = self._parser_fn(is_training=self._is_training)
    return parser_fn

  def _build_dataset_from_records(self) -> tf.data.Dataset:
    """Build a tf.data.Dataset object from input SSTables.

    If the input data come from multiple SSTables, use the user defined
    sampling weights to perform sampling. For example, if the sampling weights
    are [1., 2.], the second dataset will be sampled twice as often as the
    first one.

    Returns:
      Dataset built from SSTables.

    Raises:
      ValueError for inability to find SSTable files.
    """
    all_file_patterns = []
    if self._use_sampling:
      for file_pattern in self._input_paths:
        all_file_patterns.append([file_pattern])
      # Normalize sampling probabilities.
      total_weight = sum(self._sampling_weights)
      sampling_probabilities = [
          float(w) / total_weight for w in self._sampling_weights
      ]
    else:
      all_file_patterns.append(self._input_paths)

    datasets = []
    for file_pattern in all_file_patterns:
      filenames = sum(list(map(tf.io.gfile.glob, file_pattern)), [])
      if not filenames:
        raise ValueError(
            f'Error trying to read input files for file pattern '
            f'{file_pattern}')
      # Create a dataset of filenames and shuffle the files. In each epoch,
      # the file order is shuffled again. This may help if
      # per_host_input_for_training = false on TPU.
      dataset = tf.data.Dataset.list_files(
          file_pattern, shuffle=self._is_training)
      if self._is_training:
        dataset = dataset.repeat()

      if self._max_intra_op_parallelism:
        # Disable intra-op parallelism to optimize for throughput instead of
        # latency.
        options = tf.data.Options()
        options.experimental_threading.max_intra_op_parallelism = 1
        dataset = dataset.with_options(options)

      dataset = dataset.interleave(
          self._fetch_dataset,
          cycle_length=self._cycle_length,
          num_parallel_calls=self._cycle_length,
          deterministic=(not self._is_training))
      if self._is_training:
        dataset = dataset.shuffle(self._shuffle_buffer_size)
      datasets.append(dataset)

    if self._use_sampling:
      assert len(datasets) == len(sampling_probabilities)
      dataset = tf.data.experimental.sample_from_datasets(
          datasets, sampling_probabilities)
    else:
      dataset = datasets[0]
    return dataset
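A rough usage sketch of ours: the parser class name and its calling convention are assumptions based on the gin config in this commit, which binds `InputFn.parser_fn = @UniDetectorParserFn`:

```python
from official.projects.unified_detector.data_loaders import input_reader
from official.projects.unified_detector.data_loaders import (
    universal_detection_parser)

input_fn = input_reader.InputFn(
    is_training=True,
    batch_size=8,
    input_paths=['/path/to/tfrecords/file-prefix*'],
    dataset_type='tfrecord',
    parser_fn=universal_detection_parser.UniDetectorParserFn)
dataset = input_fn(params=None)  # falls back to the constructor batch_size
```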
official/projects/unified_detector/data_loaders/tf_example_decoder.py
0 → 100644
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tensorflow Example proto decoder for GOCR."""

from typing import List, Optional, Sequence, Tuple, Union

import tensorflow as tf

from official.projects.unified_detector.utils.typing import TensorDict
from official.vision.dataloaders import decoder


class TfExampleDecoder(decoder.Decoder):
  """Tensorflow Example proto decoder."""

  def __init__(self,
               use_instance_mask: bool = False,
               additional_class_names: Optional[Sequence[str]] = None,
               additional_regression_names: Optional[Sequence[str]] = None,
               num_additional_channels: int = 0):
    """Constructor.

    keys_to_features is a dictionary mapping the names of the tf.Example
    fields to tf features, possibly with defaults.

    Uses fixed length for scalars and variable length for vectors.

    Args:
      use_instance_mask: if False, prevents decoding of the instance mask,
        which can take a lot of resources.
      additional_class_names: If not None, a list of additional class names.
        For an additional class name n, fields named image/object/${n} are
        expected to be an int vector of length one, and are mapped to the
        tensor dict key groundtruth_${n}.
      additional_regression_names: If not None, a list of additional
        regression output names. For an additional regression name n, fields
        named image/object/${n} are expected to be a float vector, and are
        mapped to the tensor dict key groundtruth_${n}.
      num_additional_channels: The number of additional channels of
        information present in the tf.Example proto.
    """
    self._num_additional_channels = num_additional_channels
    self._use_instance_mask = use_instance_mask
    self.keys_to_features = {}
    # Map names in the final tensor dict (output of `self.decode()`) to names
    # in tf examples, e.g. 'groundtruth_text' -> 'image/object/text'.
    self.name_to_key = {}
    if use_instance_mask:
      self.keys_to_features.update({
          'image/object/mask': tf.io.VarLenFeature(tf.string),
      })
    # Now we have lists of standard types.
    # To add new features, just add entries here.
    # The tuple elements are (example name, tensor name, default value).
    # If the items_to_handlers part is already set up, use None for the
    # tensor name.
    # There are other tensor names listed as None which we probably want to
    # discuss and specify.
    scalar_strings = [
        ('image/encoded', None, ''),
        ('image/format', None, 'jpg'),
        ('image/additional_channels/encoded', None, ''),
        ('image/additional_channels/format', None, 'png'),
        ('image/label_type', 'label_type', ''),
        ('image/key', 'key', ''),
        ('image/source_id', 'source_id', ''),
    ]
    vector_strings = [
        ('image/attributes', None, ''),
        ('image/object/text', 'groundtruth_text', ''),
        ('image/object/encoded_text', 'groundtruth_encoded_text', ''),
        ('image/object/vertices', 'groundtruth_vertices', ''),
        ('image/object/object_type', None, ''),
        ('image/object/language', 'language', ''),
        ('image/object/reorderer_type', None, ''),
        ('image/label_map_path', 'label_map_path', ''),
    ]
    scalar_ints = [
        ('image/height', None, 1),
        ('image/width', None, 1),
        ('image/channels', None, 3),
    ]
    vector_ints = [
        ('image/object/classes', 'groundtruth_classes', 0),
        ('image/object/frame_id', 'frame_id', 0),
        ('image/object/track_id', 'track_id', 0),
        ('image/object/content_type', 'groundtruth_content_type', 0),
    ]
    if additional_class_names:
      vector_ints += [('image/object/%s' % name, 'groundtruth_%s' % name, 0)
                      for name in additional_class_names]
    # This one is not yet needed:
    # scalar_floats = [
    # ]
    vector_floats = [
        ('image/object/weight', 'groundtruth_weight', 0),
        ('image/object/rbox_tl_x', None, 0),
        ('image/object/rbox_tl_y', None, 0),
        ('image/object/rbox_width', None, 0),
        ('image/object/rbox_height', None, 0),
        ('image/object/rbox_angle', None, 0),
        ('image/object/bbox/xmin', None, 0),
        ('image/object/bbox/xmax', None, 0),
        ('image/object/bbox/ymin', None, 0),
        ('image/object/bbox/ymax', None, 0),
    ]
    if additional_regression_names:
      vector_floats += [('image/object/%s' % name, 'groundtruth_%s' % name, 0)
                        for name in additional_regression_names]

    self._init_scalar_features(scalar_strings, tf.string)
    self._init_vector_features(vector_strings, tf.string)
    self._init_scalar_features(scalar_ints, tf.int64)
    self._init_vector_features(vector_ints, tf.int64)
    self._init_vector_features(vector_floats, tf.float32)

  def _init_scalar_features(
      self, feature_list: List[Tuple[str, Optional[str], Union[str, int,
                                                               float]]],
      ftype: tf.dtypes.DType) -> None:
    for entry in feature_list:
      self.keys_to_features[entry[0]] = tf.io.FixedLenFeature(
          (), ftype, default_value=entry[2])
      if entry[1] is not None:
        self.name_to_key[entry[1]] = entry[0]

  def _init_vector_features(
      self, feature_list: List[Tuple[str, Optional[str], Union[str, int,
                                                               float]]],
      ftype: tf.dtypes.DType) -> None:
    for entry in feature_list:
      self.keys_to_features[entry[0]] = tf.io.VarLenFeature(ftype)
      if entry[1] is not None:
        self.name_to_key[entry[1]] = entry[0]

  def _decode_png_instance_masks(self,
                                 keys_to_tensors: TensorDict) -> tf.Tensor:
    """Decode PNG instance segmentation masks and stack into dense tensor.

    The instance segmentation masks are reshaped to [num_instances, height,
    width].

    Args:
      keys_to_tensors: A dictionary from keys to tensors.

    Returns:
      A 3-D float tensor of shape [num_instances, height, width] with values
      in {0, 1}.
    """

    def decode_png_mask(image_buffer):
      image = tf.squeeze(
          tf.image.decode_image(image_buffer, channels=1), axis=2)
      image.set_shape([None, None])
      # Note: the original TF1 calls (tf.to_float / tf.to_int32 /
      # tf.sparse_tensor_to_dense) are replaced with their TF2 equivalents.
      image = tf.cast(tf.greater(image, 0), tf.float32)
      return image

    png_masks = keys_to_tensors['image/object/mask']
    height = keys_to_tensors['image/height']
    width = keys_to_tensors['image/width']
    if isinstance(png_masks, tf.SparseTensor):
      png_masks = tf.sparse.to_dense(png_masks, default_value='')
    return tf.cond(
        tf.greater(tf.size(png_masks), 0),
        lambda: tf.map_fn(decode_png_mask, png_masks, dtype=tf.float32),
        lambda: tf.zeros(tf.cast(tf.stack([0, height, width]), tf.int32)))

  def _decode_image(self,
                    parsed_tensors: TensorDict,
                    channel: int = 3) -> TensorDict:
    """Decodes the image and sets its shape (H, W are dynamic; C is fixed)."""
    image = tf.io.decode_image(
        parsed_tensors['image/encoded'], channels=channel)
    image.set_shape([None, None, channel])
    return {'image': image}

  def _decode_additional_channels(self,
                                  parsed_tensors: TensorDict,
                                  channel: int = 3) -> TensorDict:
    """Decodes the additional channels and sets their static shape."""
    channels = tf.io.decode_image(
        parsed_tensors['image/additional_channels/encoded'], channels=channel)
    channels.set_shape([None, None, channel])
    return {'additional_channels': channels}

  def _decode_boxes(self, parsed_tensors: TensorDict) -> TensorDict:
    """Concat box coordinates in the format of [ymin, xmin, ymax, xmax]."""
    xmin = parsed_tensors['image/object/bbox/xmin']
    xmax = parsed_tensors['image/object/bbox/xmax']
    ymin = parsed_tensors['image/object/bbox/ymin']
    ymax = parsed_tensors['image/object/bbox/ymax']
    return {
        'groundtruth_aligned_boxes':
            tf.stack([ymin, xmin, ymax, xmax], axis=-1)
    }

  def _decode_rboxes(self, parsed_tensors: TensorDict) -> TensorDict:
    """Concat rbox coordinates: [left, top, box_width, box_height, angle]."""
    top_left_x = parsed_tensors['image/object/rbox_tl_x']
    top_left_y = parsed_tensors['image/object/rbox_tl_y']
    width = parsed_tensors['image/object/rbox_width']
    height = parsed_tensors['image/object/rbox_height']
    angle = parsed_tensors['image/object/rbox_angle']
    return {
        'groundtruth_boxes':
            tf.stack([top_left_x, top_left_y, width, height, angle], axis=-1)
    }

  def _decode_masks(self, parsed_tensors: TensorDict) -> TensorDict:
    """Decode a set of PNG masks to the tf.float32 tensors."""

    def _decode_png_mask(png_bytes):
      mask = tf.squeeze(
          tf.io.decode_png(png_bytes, channels=1, dtype=tf.uint8), axis=-1)
      mask = tf.cast(mask, dtype=tf.float32)
      mask.set_shape([None, None])
      return mask

    height = parsed_tensors['image/height']
    width = parsed_tensors['image/width']
    masks = parsed_tensors['image/object/mask']
    masks = tf.cond(
        pred=tf.greater(tf.size(input=masks), 0),
        true_fn=lambda: tf.map_fn(_decode_png_mask, masks, dtype=tf.float32),
        false_fn=lambda: tf.zeros([0, height, width], dtype=tf.float32))
    return {'groundtruth_instance_masks': masks}

  def decode(self, tf_example_string_tensor: tf.string):
    """Decodes a serialized tensorflow example and returns a tensor dict.

    Args:
      tf_example_string_tensor: A string tensor holding a serialized
        tensorflow example proto.

    Returns:
      A dictionary containing a subset of the following, depending on the
      inputs:
        image: A uint8 tensor of shape [height, width, 3] containing the
          image.
        source_id: A string tensor containing the image fingerprint.
        key: A string tensor containing the unique sha256 hash key.
        label_type: Either `full` or `partial`. `full` means all the text is
          fully labeled, `partial` otherwise. Currently, this is used by the
          E2E model. If an input image is fully labeled, we update the
          weights of both the detector and the recognizer. Otherwise, only
          the recognizer part of the model is trained.
        groundtruth_text: A string tensor list, the original transcriptions.
        groundtruth_encoded_text: A string tensor list, the class ids for
          the atoms in the text, after applying the reordering algorithm, in
          string form, e.g. "90,71,85,69,86,85,93,90,71,91,1,71,85,93,90,71".
          This depends on the class label map provided to the conversion
          program. These are 0 based, with -1 for OOV symbols.
        groundtruth_classes: An int32 tensor of shape [num_boxes] containing
          the class ids. Note these are 1 based; 0 is reserved for the
          background class.
        groundtruth_content_type: An int32 tensor of shape [num_boxes]
          containing the content types. Values correspond to
          PageLayoutEntity::ContentType.
        groundtruth_weight: An int32 tensor of shape [num_boxes], either 0
          or 1. If a region has weight 0, it will be ignored when computing
          the losses.
        groundtruth_boxes: A float tensor of shape [num_boxes, 5] containing
          the groundtruth rotated rectangles. Each row is in [left, top,
          box_width, box_height, angle] order; absolute coordinates are
          used.
        groundtruth_aligned_boxes: A float tensor of shape [num_boxes, 4]
          containing the groundtruth axis-aligned rectangles. Each row is in
          [ymin, xmin, ymax, xmax] order. Currently, this is used to store
          groundtruth symbol boxes.
        groundtruth_vertices: A string tensor list containing encoded
          normalized box or polygon coordinates, e.g.
          `x1,y1,x2,y2,x3,y3,x4,y4`.
        groundtruth_instance_masks: A float tensor of shape [num_boxes,
          height, width] containing binarized image-sized instance
          segmentation masks: `1.0` for positive regions, `0.0` otherwise.
          None if not in the tf.Example.
        frame_id: An int32 tensor of shape [num_boxes], either `0` or `1`.
          `0` means the object comes from the first image, `1` the second.
        track_id: An int32 tensor of shape [num_boxes], whose values
          indicate identity across frame indices.
        additional_channels: A uint8 tensor of shape [H, W, C] representing
          some extra features.
    """
    parsed_tensors = tf.io.parse_single_example(
        serialized=tf_example_string_tensor, features=self.keys_to_features)
    for k in parsed_tensors:
      if isinstance(parsed_tensors[k], tf.SparseTensor):
        if parsed_tensors[k].dtype == tf.string:
          parsed_tensors[k] = tf.sparse.to_dense(
              parsed_tensors[k], default_value='')
        else:
          parsed_tensors[k] = tf.sparse.to_dense(
              parsed_tensors[k], default_value=0)
    decoded_tensors = {}
    decoded_tensors.update(self._decode_image(parsed_tensors))
    decoded_tensors.update(self._decode_rboxes(parsed_tensors))
    decoded_tensors.update(self._decode_boxes(parsed_tensors))
    if self._use_instance_mask:
      decoded_tensors['groundtruth_instance_masks'] = (
          self._decode_png_instance_masks(parsed_tensors))
    if self._num_additional_channels:
      decoded_tensors.update(
          self._decode_additional_channels(parsed_tensors,
                                           self._num_additional_channels))
    # Other attributes:
    for key in self.name_to_key:
      if key not in decoded_tensors:
        decoded_tensors[key] = parsed_tensors[self.name_to_key[key]]
    if 'groundtruth_instance_masks' not in decoded_tensors:
      decoded_tensors['groundtruth_instance_masks'] = None
    return decoded_tensors
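
A minimal sketch of driving this decoder eagerly (the tiny example below is made up; a real record would also carry box fields, transcriptions, and additional channels):

import tensorflow as tf

decoder = TfExampleDecoder(additional_class_names=['parent'])

# A toy serialized tf.Example with just an encoded image; every other
# declared field falls back to its default value.
image_bytes = tf.io.encode_jpeg(tf.zeros((8, 8, 3), tf.uint8)).numpy()
example = tf.train.Example(
    features=tf.train.Features(
        feature={
            'image/encoded':
                tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[image_bytes])),
        }))
tensors = decoder.decode(tf.constant(example.SerializeToString()))
print(tensors['image'].shape)  # (8, 8, 3)
print(tensors['groundtruth_boxes'].shape)  # (0, 5): no rboxes in the record.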
official/projects/unified_detector/data_loaders/universal_detection_parser.py
0 → 100644
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data parser for universal detector."""

import enum
import functools
from typing import Any, Tuple

import gin
import tensorflow as tf

from official.projects.unified_detector.data_loaders import autoaugment
from official.projects.unified_detector.data_loaders import tf_example_decoder
from official.projects.unified_detector.utils import utilities
from official.projects.unified_detector.utils.typing import NestedTensorDict
from official.projects.unified_detector.utils.typing import TensorDict


@gin.constants_from_enum
class DetectionClass(enum.IntEnum):
  """As in `PageLayoutEntity.EntityType`."""
  WORD = 0
  LINE = 2
  PARAGRAPH = 3
  BLOCK = 4


NOT_ANNOTATED_ID = 8


def _erase(mask: tf.Tensor,
           feature: tf.Tensor,
           min_val: float = 0.,
           max_val: float = 256.) -> tf.Tensor:
  """Erase the feature maps with a mask.

  Erase feature maps with a mask and replace the erased area with uniform
  random noise. The mask can have a different size from the feature maps.

  Args:
    mask: an (h, w) binary mask for pixels to erase. Value 1 represents
      pixels to erase.
    feature: the (H, W, C) feature maps to erase from.
    min_val: The minimum value of random noise.
    max_val: The maximum value of random noise.

  Returns:
    The (H, W, C) feature maps, with pixels in mask replaced with noise. It's
    equal to mask * noise + (1 - mask) * feature.
  """
  h, w, c = utilities.resolve_shape(feature)
  resized_mask = tf.image.resize(
      tf.tile(tf.expand_dims(tf.cast(mask, tf.float32), -1), (1, 1, c)),
      (h, w))
  erased = tf.where(
      condition=(resized_mask > 0.5),
      x=tf.cast(tf.random.uniform((h, w, c), min_val, max_val), feature.dtype),
      y=feature)
  return erased
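

# Illustrative sketch (not part of the module): erasing the top half of a
# uniform image with a low-resolution mask. `_erase` upsamples the (2, 2)
# mask to the (4, 4) feature size before substituting noise:
#
#   feature = tf.fill((4, 4, 3), 100.)
#   mask = tf.constant([[1, 1],
#                       [0, 0]])
#   out = _erase(mask, feature)  # out[:2] is noise; out[2:] still equals 100.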


@gin.configurable(denylist=['is_training'])
class UniDetectorParserFn(object):
  """Data parser for universal detector."""

  def __init__(
      self,
      is_training: bool,
      output_dimension: int = 1025,
      mask_dimension: int = -1,
      max_num_instance: int = 128,
      rot90_probability: float = 0.5,
      use_color_distortion: bool = True,
      randaug_mag: float = 5.,
      randaug_std: float = 0.5,
      randaug_layer: int = 2,
      randaug_prob: float = 0.5,
      use_cropping: bool = True,
      crop_min_scale: float = 0.5,
      crop_max_scale: float = 1.5,
      crop_min_aspect: float = 4 / 5,
      crop_max_aspect: float = 5 / 4,
      is_shape_defined: bool = True,
      use_tpu: bool = True,
      detection_unit: DetectionClass = DetectionClass.LINE,
  ):
"""Constructor.
Args:
is_training: bool indicating TRAIN or EVAL.
output_dimension: The size of input images.
mask_dimension: The size of the output mask. If negative or zero, it will
be set the same as output_dimension.
max_num_instance: The maximum number of instances to output. If it's
negative, padding or truncating will not be performed.
rot90_probability: The probability of rotating multiples of 90 degrees.
use_color_distortion: Whether to apply color distortions to images (via
autoaugment).
randaug_mag: (autoaugment parameter) Color distortion magnitude. Note
that, this value should be set conservatively, as some color distortions
can easily make text illegible e.g. posterize.
randaug_std: (autoaugment parameter) Randomness in color distortion
magnitude.
randaug_layer: (autoaugment parameter) Number of color distortion
operations.
randaug_prob: (autoaugment parameter) Probabilily of applying each
distortion operation.
use_cropping: Bool, whether to use random cropping and resizing in
training.
crop_min_scale: The minimum scale of a random crop.
crop_max_scale: The maximum scale of a random crop. If >1, it means the
images are downsampled.
crop_min_aspect: The minimum aspect ratio of a random crop.
crop_max_aspect: The maximum aspect ratio of a random crop.
is_shape_defined: Whether to define the static shapes for all features and
labels. This must be set to True in TPU training as it requires static
shapes for all tensors.
use_tpu: Whether the inputs are fed to a TPU device.
detection_unit: Whether word or line (or else) is regarded as an entity.
The instance masks will be at word or line level.
"""
    if is_training and max_num_instance < 0:
      raise ValueError('In TRAIN mode, padding/truncation is required.')
    self._is_training = is_training
    self._output_dimension = output_dimension
    self._mask_dimension = (
        mask_dimension if mask_dimension > 0 else output_dimension)
    self._max_num_instance = max_num_instance
    self._decoder = tf_example_decoder.TfExampleDecoder(
        num_additional_channels=3, additional_class_names=['parent'])
    self._use_color_distortion = use_color_distortion
    self._rot90_probability = rot90_probability
    self._randaug_mag = randaug_mag
    self._randaug_std = randaug_std
    self._randaug_layer = randaug_layer
    self._randaug_prob = randaug_prob
    self._use_cropping = use_cropping
    self._crop_min_scale = crop_min_scale
    self._crop_max_scale = crop_max_scale
    self._crop_min_aspect = crop_min_aspect
    self._crop_max_aspect = crop_max_aspect
    self._is_shape_defined = is_shape_defined
    self._use_tpu = use_tpu
    self._detection_unit = detection_unit

  def __call__(self, value: str) -> Tuple[TensorDict, NestedTensorDict]:
    """Parses the data.

    Args:
      value: The serialized data sample.

    Returns:
      Two dicts for features and labels.
      features:
        'source_id': id of the sample; only in EVAL mode.
        'images': the normalized images,
          (output_dimension, output_dimension, 3).
      labels:
        See `_prepare_labels` for its content.
    """
    data = self._decoder.decode(value)
    features = {}
    labels = {}
    self._preprocess(data, features, labels)
    self._rot90k(data, features, labels)
    self._crop_and_resize(data, features, labels)
    self._color_distortion_and_normalize(data, features, labels)
    self._prepare_labels(data, features, labels)
    self._define_shapes(features, labels)
    return features, labels

  def _preprocess(self, data: TensorDict, features: TensorDict,
                  unused_labels: TensorDict):
    """All kinds of preprocessing of the decoded data dict."""
    # (1) Decode the entity_id_mask: an H*W*1 mask, where each pixel equals
    # (1 + position) of the entity in the GT entity list. The IDs (which can
    # be larger than 255) are stored in the last two channels.
    data['additional_channels'] = tf.cast(data['additional_channels'],
                                          tf.int32)
    entity_id_mask = (
        data['additional_channels'][:, :, -2:-1] * 256 +
        data['additional_channels'][:, :, -1:])
    data['entity_id_mask'] = entity_id_mask
    # (2) Write image id. Used in evaluation.
    if not self._use_tpu:
      features['source_id'] = data['source_id']
    # (3) Block mask: area without annotation.
    data['image'] = _erase(
        data['additional_channels'][:, :, 0],
        data['image'],
        min_val=0.,
        max_val=256.)

  def _rot90k(self, data: TensorDict, unused_features: TensorDict,
              unused_labels: TensorDict):
    """Rotate the image, gt_bboxes, masks by 90k degrees."""
    if not self._is_training:
      return
    rotate_90_choice = tf.random.uniform([])

    def _rotate():
      """Rotation.

      These will be rotated:
        image,
        rbox,
        entity_id_mask,
      TODO(longshangbang): rotate vertices.

      Returns:
        The rotated tensors of the above fields.
      """
      k = tf.random.uniform([], 1, 4, dtype=tf.int32)
      h, w, _ = utilities.resolve_shape(data['image'])
      # Image
      rotated_img = tf.image.rot90(data['image'], k=k, name='image_rot90k')
      # Box
      rotate_box_op = functools.partial(
          utilities.rotate_rboxes90,
          rboxes=data['groundtruth_boxes'],
          image_width=w,
          image_height=h)
      rotated_boxes = tf.switch_case(
          k - 1,  # Indices start with 1.
          branch_fns=[
              lambda: rotate_box_op(rotation_count=1),
              lambda: rotate_box_op(rotation_count=2),
              lambda: rotate_box_op(rotation_count=3)
          ])
      # Mask
      rotated_mask = tf.image.rot90(
          data['entity_id_mask'], k=k, name='mask_rot90k')
      return rotated_img, rotated_boxes, rotated_mask

    # pylint: disable=g-long-lambda
    (data['image'], data['groundtruth_boxes'],
     data['entity_id_mask']) = tf.cond(
         rotate_90_choice < self._rot90_probability, _rotate,
         lambda: (data['image'], data['groundtruth_boxes'],
                  data['entity_id_mask']))
    # pylint: enable=g-long-lambda

  def _crop_and_resize(self, data: TensorDict, unused_features: TensorDict,
                       unused_labels: TensorDict):
    """Perform random cropping and resizing."""
    # TODO(longshangbang): resize & translate box as well
    # TODO(longshangbang): resize & translate vertices as well
    # Get cropping target.
    h, w = utilities.resolve_shape(data['image'])[:2]
    left, top, crop_w, crop_h, pad_w, pad_h = self._get_crop_box(
        tf.cast(h, tf.float32), tf.cast(w, tf.float32))

    # Crop the image. (Pad the image if the crop box is larger than the
    # image.)
    if self._is_training:
      # padding left, top, right, bottom
      pad_left = tf.random.uniform([], 0, pad_w + 1, dtype=tf.int32)
      pad_top = tf.random.uniform([], 0, pad_h + 1, dtype=tf.int32)
    else:
      pad_left = 0
      pad_top = 0
    cropped_img = tf.image.crop_to_bounding_box(data['image'], top, left,
                                                crop_h, crop_w)
    padded_img = tf.pad(
        cropped_img,
        [[pad_top, pad_h - pad_top], [pad_left, pad_w - pad_left], [0, 0]],
        constant_values=127)
    # Resize images
    data['resized_image'] = tf.image.resize(
        padded_img, (self._output_dimension, self._output_dimension))
    data['resized_image'] = tf.cast(data['resized_image'], tf.uint8)
    # Crop the masks
    cropped_masks = tf.image.crop_to_bounding_box(data['entity_id_mask'],
                                                  top, left, crop_h, crop_w)
    padded_masks = tf.pad(
        cropped_masks,
        [[pad_top, pad_h - pad_top], [pad_left, pad_w - pad_left], [0, 0]])
    # Resize masks
    data['resized_masks'] = tf.image.resize(
        padded_masks, (self._mask_dimension, self._mask_dimension),
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    data['resized_masks'] = tf.squeeze(data['resized_masks'], -1)

  def _get_crop_box(
      self, h: tf.Tensor,
      w: tf.Tensor) -> Tuple[Any, Any, tf.Tensor, tf.Tensor, Any, Any]:
    """Get the cropping box.

    Args:
      h: The height of the image to crop. Should be float type.
      w: The width of the image to crop. Should be float type.

    Returns:
      A tuple representing (left, top, crop_w, crop_h, pad_w, pad_h).
      Then, in `self._crop_and_resize`, a crop will be extracted with the
      bounding box from top-left corner (left, top) and with size
      (crop_w, crop_h). This crop will then be padded with (pad_w, pad_h) to
      square sizes. The outputs are also re-cast to int32 type.
    """
    if not self._is_training or not self._use_cropping:
      # cast back to integers.
      w = tf.cast(w, tf.int32)
      h = tf.cast(h, tf.int32)
      side = tf.maximum(w, h)
      return 0, 0, w, h, side - w, side - h

    # Get box size
    scale = tf.random.uniform([], self._crop_min_scale, self._crop_max_scale)
    max_edge = tf.maximum(w, h)
    long_edge = max_edge * scale
    sqrt_aspect_ratio = tf.math.sqrt(
        tf.random.uniform([], self._crop_min_aspect, self._crop_max_aspect))
    box_h = long_edge / sqrt_aspect_ratio
    box_w = long_edge * sqrt_aspect_ratio
    # Get box location
    left = tf.random.uniform([], 0., tf.maximum(0., w - box_w))
    top = tf.random.uniform([], 0., tf.maximum(0., h - box_h))
    # Get crop & pad
    crop_w = tf.minimum(box_w, w - left)
    crop_h = tf.minimum(box_h, h - top)
    pad_w = box_w - crop_w
    pad_h = box_h - crop_h
    return (tf.cast(left, tf.int32), tf.cast(top, tf.int32),
            tf.cast(crop_w, tf.int32), tf.cast(crop_h, tf.int32),
            tf.cast(pad_w, tf.int32), tf.cast(pad_h, tf.int32))
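
  # Illustrative walk-through (not part of the parser): for a 200x100 image
  # (h=100., w=200.), suppose scale = 0.8 and the sampled aspect ratio is 1.
  # Then long_edge = 200 * 0.8 = 160, so box_w = box_h = 160. The box is
  # taller than the image, so top = 0, crop_h = min(160, 100) = 100, and
  # pad_h = 60: the crop keeps the full height and pads 60 rows to reach the
  # 160x160 target before resizing to `output_dimension`.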

  def _color_distortion_and_normalize(self, data: TensorDict,
                                      features: TensorDict,
                                      unused_labels: TensorDict):
    """Distort colors."""
    if self._is_training and self._use_color_distortion:
      data['resized_image'] = autoaugment.distort_image_with_randaugment(
          data['resized_image'], self._randaug_layer, self._randaug_mag,
          self._randaug_std, True, self._randaug_prob, True)
    # Normalize
    features['images'] = utilities.normalize_image_to_range(
        data['resized_image'])

  def _prepare_labels(self, data: TensorDict, features: TensorDict,
                      labels: TensorDict):
    """This function prepares the labels.

    The following targets are added to labels['segmentation_output']:
      'gt_word_score': An (h, w) float32 mask for textness score; 1 for
        word, 0 for background.
    The following targets are added to labels['instance_labels']:
      'num_instance': A float scalar tensor for the total number of
        instances. It is bounded by the maximum number of instances allowed.
        It includes the special background instance, so it equals
        (1 + number of entities).
      'masks': An (h, w) int32 mask of entity IDs. The value of each pixel
        is the id of the entity it belongs to. A value of `0` means the
        background mask.
      'classes': A (max_num,) int tensor indicating the class of each
        instance:
          2 for background
          1 for text entity
          0 for non-object
      'masks_sizes': A (max_num,) float tensor for the size of all masks.
      'gt_weights': Whether it's difficult / does not have text annotation.
    The following targets are added to labels['paragraph_labels']:
      'paragraph_ids': A (max_num,) integer tensor for the paragraph id. If
        `-1`, this text has no paragraph label.
      'has_para_ids': A float scalar; 1.0 if the sample has paragraph
        labels.

    Args:
      data: The data dictionary.
      features: The feature dict.
      labels: The label dict.
    """
    # Segmentation labels:
    self._get_segmentation_labels(data, features, labels)
    # Instance labels:
    self._get_instance_labels(data, features, labels)

  def _get_segmentation_labels(self, data: TensorDict,
                               unused_features: TensorDict,
                               labels: NestedTensorDict):
    labels['segmentation_output'] = {
        'gt_word_score': tf.cast((data['resized_masks'] > 0), tf.float32)
    }

  def _get_instance_labels(self, data: TensorDict, features: TensorDict,
                           labels: NestedTensorDict):
    """Generate the labels for text entity detection."""
    labels['instance_labels'] = {}
    # (1) Depending on `detection_unit`:
    # Convert the word-id map to a line-id map, or use the word-id map
    # directly. Word entity ids start from 1 in the map, so pad a -1 at the
    # beginning of the parent list to counter this offset.
    padded_parent = tf.concat(
        [tf.constant([-1]),
         tf.cast(data['groundtruth_parent'], tf.int32)], 0)
    if self._detection_unit == DetectionClass.WORD:
      entity_id_mask = data['resized_masks']
    elif self._detection_unit == DetectionClass.LINE:
      # The pixel value is entity_id + 1, shape = [H, W]; 0 for background.
      # correctness:
      #   0s in data['resized_masks'] --> padded_parent[0] == -1
      #   i-th entity in plp.entities --> i+1 in data['resized_masks']
      #                               --> padded_parent[i+1]
      #                               --> data['groundtruth_parent'][i]
      #                               --> the parent of i-th entity
      entity_id_mask = tf.gather(padded_parent, data['resized_masks']) + 1
    elif self._detection_unit == DetectionClass.PARAGRAPH:
      # directly segmenting paragraphs; two hops here.
      entity_id_mask = tf.gather(padded_parent, data['resized_masks']) + 1
      entity_id_mask = tf.gather(padded_parent, entity_id_mask) + 1
    else:
      raise ValueError(f'No such detection unit: {self._detection_unit}')
    data['entity_id_mask'] = entity_id_mask

    # (2) Get individual masks for entities.
    entity_selection_mask = tf.equal(data['groundtruth_classes'],
                                     self._detection_unit)
    num_all_entity = utilities.resolve_shape(data['groundtruth_classes'])[0]
    # entity_ids is a 1-D tensor for IDs of all entities of a certain type.
    entity_ids = tf.boolean_mask(
        tf.range(num_all_entity, dtype=tf.int32), entity_selection_mask)  # (N,)
    # +1 to match the entity ids in entity_id_mask
    entity_ids = tf.reshape(entity_ids, (-1, 1, 1)) + 1
    individual_masks = tf.expand_dims(entity_id_mask, 0)
    individual_masks = tf.equal(entity_ids, individual_masks)  # (N, H, W), bool
    # TODO(longshangbang): replace with real mask sizes computing.
    # Currently, we use full-resolution masks for individual_masks. In order
    # to compute mask sizes, we need to convert individual_masks to int/float
    # type. This will cause OOM because the mask is too large.
    masks_sizes = tf.cast(
        tf.reduce_any(individual_masks, axis=[1, 2]), tf.float32)
    # remove empty masks (usually caused by cropping)
    non_empty_masks_ids = tf.not_equal(masks_sizes, 0)
    valid_masks = tf.boolean_mask(individual_masks, non_empty_masks_ids)
    valid_entity_ids = tf.boolean_mask(entity_ids,
                                       non_empty_masks_ids)[:, 0, 0]

    # (3) Write num of instance
    num_instance = tf.reduce_sum(tf.cast(non_empty_masks_ids, tf.float32))
    num_instance_and_bkg = num_instance + 1
    if self._max_num_instance >= 0:
      num_instance_and_bkg = tf.minimum(num_instance_and_bkg,
                                        self._max_num_instance)
    labels['instance_labels']['num_instance'] = num_instance_and_bkg

    # (4) Write instance masks
    num_entity_int = tf.cast(num_instance, tf.int32)
    max_num_entities = self._max_num_instance - 1  # Spare 1 for bkg.
    pad_num = tf.maximum(max_num_entities - num_entity_int, 0)
    padded_valid_masks = tf.pad(valid_masks, [[0, pad_num], [0, 0], [0, 0]])
    # If there are more instances than allowed, randomly sample some.
    # `random_selection_mask` is a 0/1 array; the maximum number of 1s is
    # `self._max_num_instance`; if not bounded, it's an array of all 1s.
    if self._max_num_instance >= 0:
      padded_size = num_entity_int + pad_num
      random_selection = tf.random.uniform((padded_size,), dtype=tf.float32)
      selected_indices = tf.math.top_k(random_selection,
                                       k=max_num_entities)[1]
      random_selection_mask = tf.scatter_nd(
          indices=tf.expand_dims(selected_indices, axis=-1),
          updates=tf.ones((max_num_entities,), dtype=tf.bool),
          shape=(padded_size,))
    else:
      random_selection_mask = tf.ones((num_entity_int,), dtype=tf.bool)
    random_discard_mask = tf.logical_not(random_selection_mask)
    kept_masks = tf.boolean_mask(padded_valid_masks, random_selection_mask)
    erased_masks = tf.boolean_mask(padded_valid_masks, random_discard_mask)
    erased_masks = tf.cast(tf.reduce_any(erased_masks, axis=0), tf.float32)
    # erase text instances that are omitted.
    features['images'] = _erase(erased_masks, features['images'], -1., 1.)
    labels['segmentation_output']['gt_word_score'] *= 1. - erased_masks
    kept_masks_and_bkg = tf.concat(
        [
            tf.math.logical_not(
                tf.reduce_any(kept_masks, axis=0, keepdims=True)),  # bkg
            kept_masks,
        ], 0)
    labels['instance_labels']['masks'] = tf.argmax(kept_masks_and_bkg, axis=0)

    # (5) Write mask size
    # TODO(longshangbang): replace with real masks sizes
    masks_sizes = tf.cast(
        tf.reduce_any(kept_masks_and_bkg, axis=[1, 2]), tf.float32)
    labels['instance_labels']['masks_sizes'] = masks_sizes

    # (6) Write classes.
    classes = tf.ones((num_instance,), dtype=tf.int32)
    classes = tf.concat([tf.constant(2, tf.int32, (1,)), classes], 0)  # bkg
    if self._max_num_instance >= 0:
      classes = utilities.truncate_or_pad(classes, self._max_num_instance, 0)
    labels['instance_labels']['classes'] = classes

    # (7) gt-weights
    selected_ids = tf.boolean_mask(valid_entity_ids,
                                   random_selection_mask[:num_entity_int])
    if self._detection_unit != DetectionClass.PARAGRAPH:
      gt_text = tf.gather(data['groundtruth_text'], selected_ids - 1)
      gt_weights = tf.cast(tf.strings.length(gt_text) > 0, tf.float32)
    else:
      text_types = tf.concat(
          [
              tf.constant([8]),
              tf.cast(data['groundtruth_content_type'], tf.int32),
              # TODO(longshangbang): temp solution for tfes with no para
              # labels
              tf.constant(8, shape=(1000,)),
          ], 0)
      para_types = tf.gather(text_types, selected_ids)
      gt_weights = tf.cast(
          tf.not_equal(para_types, NOT_ANNOTATED_ID), tf.float32)
    gt_weights = tf.concat([tf.constant(1., shape=(1,)), gt_weights], 0)  # bkg
    if self._max_num_instance >= 0:
      gt_weights = utilities.truncate_or_pad(gt_weights,
                                             self._max_num_instance, 0)
    labels['instance_labels']['gt_weights'] = gt_weights

    # (8) get paragraph label
    # In this step, an array `{p_i}` is generated. `p_i` is an integer that
    # indicates the paragraph group which the i-th text belongs to.
    # `p_i` == -1 if this instance is non-text or has no paragraph label.
    # word -> line -> paragraph
    if self._detection_unit == DetectionClass.WORD:
      num_hop = 2
    elif self._detection_unit == DetectionClass.LINE:
      num_hop = 1
    elif self._detection_unit == DetectionClass.PARAGRAPH:
      num_hop = 0
    else:
      raise ValueError(f'No such detection unit: {self._detection_unit}. '
                       'Note that this error should have been raised in '
                       'previous lines, not here!')
    para_ids = tf.identity(selected_ids)  # == id in plp + 1
    for _ in range(num_hop):
      para_ids = tf.gather(padded_parent, para_ids) + 1
    text_types = tf.concat(
        [
            tf.constant([8]),
            tf.cast(data['groundtruth_content_type'], tf.int32),
            # TODO(longshangbang): tricks for tfes that have no para labels
            tf.constant(8, shape=(1000,)),
        ], 0)
    para_types = tf.gather(text_types, para_ids)
    para_ids = para_ids - 1  # revert to id in plp.entities; -1 for no labels
    valid_para = tf.cast(tf.not_equal(para_types, NOT_ANNOTATED_ID), tf.int32)
    para_ids = valid_para * para_ids + (1 - valid_para) * (-1)
    para_ids = tf.concat([tf.constant([-1]), para_ids], 0)  # add bkg
    has_para_ids = tf.cast(tf.reduce_sum(valid_para) > 0, tf.float32)
    if self._max_num_instance >= 0:
      para_ids = utilities.truncate_or_pad(para_ids, self._max_num_instance,
                                           0, -1)
    labels['paragraph_labels'] = {
        'paragraph_ids': para_ids,
        'has_para_ids': has_para_ids
    }

  def _define_shapes(self, features: TensorDict, labels: TensorDict):
    """Define the tensor shapes for TPU compiling."""
    if not self._is_shape_defined:
      return
    features['images'] = tf.ensure_shape(
        features['images'],
        (self._output_dimension, self._output_dimension, 3))
    labels['segmentation_output']['gt_word_score'] = tf.ensure_shape(
        labels['segmentation_output']['gt_word_score'],
        (self._mask_dimension, self._mask_dimension))
    labels['instance_labels']['num_instance'] = tf.ensure_shape(
        labels['instance_labels']['num_instance'], [])
    if self._max_num_instance >= 0:
      labels['instance_labels']['masks_sizes'] = tf.ensure_shape(
          labels['instance_labels']['masks_sizes'],
          (self._max_num_instance,))
      labels['instance_labels']['masks'] = tf.ensure_shape(
          labels['instance_labels']['masks'],
          (self._mask_dimension, self._mask_dimension))
      labels['instance_labels']['classes'] = tf.ensure_shape(
          labels['instance_labels']['classes'], (self._max_num_instance,))
      labels['instance_labels']['gt_weights'] = tf.ensure_shape(
          labels['instance_labels']['gt_weights'], (self._max_num_instance,))
      labels['paragraph_labels']['paragraph_ids'] = tf.ensure_shape(
          labels['paragraph_labels']['paragraph_ids'],
          (self._max_num_instance,))
    labels['paragraph_labels']['has_para_ids'] = tf.ensure_shape(
        labels['paragraph_labels']['has_para_ids'], [])
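
A rough sketch of how this parser is typically wired into an input pipeline (the file pattern below is a hypothetical placeholder; the parser arguments follow the defaults above):

import tensorflow as tf

# Hypothetical TFRecord shards produced by the data_conversion tools.
files = tf.data.Dataset.list_files('/tmp/hiertext-train-*.tfrecord')
dataset = files.interleave(tf.data.TFRecordDataset, cycle_length=4)

parser = UniDetectorParserFn(is_training=True)
dataset = dataset.map(parser, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.batch(8, drop_remainder=True)
# Each element is now a (features, labels) tuple of nested dicts with
# statically defined shapes, ready for TPU training.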
official/projects/unified_detector/docs/images/task.png
0 → 100644
official/projects/unified_detector/external_configurables.py
0 → 100644
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrap external code in gin."""

import gin
import gin.tf.external_configurables
import tensorflow as tf

# Tensorflow.
gin.external_configurable(tf.keras.layers.experimental.SyncBatchNormalization)
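
# Illustrative gin usage (not part of the module): once registered, the layer
# can be referenced from a .gin config like any other configurable, e.g. as a
# hypothetical normalization function for one of the detector heads:
#
#   _get_embed_head.norm_fn = @SyncBatchNormalization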
official/projects/unified_detector/modeling/universal_detector.py
0 → 100644
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Universal detector implementation."""

from typing import Any, Dict, Optional, Sequence, Tuple, Union

import gin
import tensorflow as tf

from deeplab2 import config_pb2
from deeplab2.model.decoder import max_deeplab as max_deeplab_head
from deeplab2.model.encoder import axial_resnet_instances
from deeplab2.model.loss import matchers_ops
from official.legacy.transformer import transformer
from official.projects.unified_detector.utils import typing
from official.projects.unified_detector.utils import utilities

EPSILON = 1e-6


@gin.configurable
def universal_detection_loss_weights(
    loss_segmentation_word: float = 1e0,
    loss_inst_dist: float = 1e0,
    loss_mask_id: float = 1e-4,
    loss_pq: float = 3e0,
    loss_para: float = 1e0) -> Dict[str, float]:
  """A function that returns a dict for the weights of loss terms."""
  return {
      "loss_segmentation_word": loss_segmentation_word,
      "loss_inst_dist": loss_inst_dist,
      "loss_mask_id": loss_mask_id,
      "loss_pq": loss_pq,
      "loss_para": loss_para,
  }
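

# Illustrative gin override (not part of the module): the individual loss
# weights can be re-balanced from a config without code changes, e.g.:
#
#   universal_detection_loss_weights.loss_pq = 5.0
#   universal_detection_loss_weights.loss_mask_id = 1e-3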


@gin.configurable
class LayerNorm(tf.keras.layers.LayerNormalization):
  """A wrapper to allow passing the `training` argument.

  The normalization layers in the MaX-DeepLab implementation are called with
  the `training` argument. This wrapper enables the use of LayerNorm.
  """

  def call(self,
           inputs: tf.Tensor,
           training: Optional[bool] = None) -> tf.Tensor:
    del training
    return super().call(inputs)


@gin.configurable
def get_max_deep_lab_backbone(num_slots: int = 128):
  return axial_resnet_instances.get_model(
      "max_deeplab_s",
      bn_layer=LayerNorm,
      block_group_config={
          "drop_path_schedule": "linear",
          "axial_use_recompute_grad": False
      },
      backbone_use_transformer_beyond_stride=16,
      extra_decoder_use_transformer_beyond_stride=16,
      num_mask_slots=num_slots,
      max_num_mask_slots=num_slots)


@gin.configurable
class UniversalDetector(tf.keras.layers.Layer):
  """Universal Detector."""

  loss_items = ("loss_pq", "loss_inst_dist", "loss_para", "loss_mask_id",
                "loss_segmentation_word")

  def __init__(self,
               backbone_fn: tf.keras.layers.Layer = get_max_deep_lab_backbone,
               mask_threshold: float = 0.4,
               class_threshold: float = 0.5,
               filter_area: float = 32,
               **kwargs: Any):
    """Constructor.

    Args:
      backbone_fn: The function to initialize a backbone.
      mask_threshold: Masks are thresholded with this value.
      class_threshold: Classification heads are thresholded with this value.
      filter_area: In inference, detections with an area smaller than this
        threshold will be removed.
      **kwargs: other keyword arguments passed to the base class.
    """
    super().__init__(**kwargs)
    # Model
    self._backbone_fn = backbone_fn()
    self._decoder = _get_decoder_head()
    self._class_embed_head, self._para_embed_head = _get_embed_head()
    self._para_head, self._para_proj = _get_para_head()
    # Losses
    # self._max_deeplab_loss = _get_max_deeplab_loss()
    self._loss_weights = universal_detection_loss_weights()
    # Post-processing
    self._mask_threshold = mask_threshold
    self._class_threshold = class_threshold
    self._filter_area = filter_area

  def _preprocess_labels(self, labels: typing.TensorDict):
    # Preprocessing
    # Convert the integer mask to one-hot embedded masks.
    num_instances = utilities.resolve_shape(
        labels["instance_labels"]["masks_sizes"])[1]
    labels["instance_labels"]["masks"] = tf.one_hot(
        labels["instance_labels"]["masks"],
        depth=num_instances,
        axis=1,
        dtype=tf.float32)  # (B, N, H, W)

  def compute_losses(
      self, labels: typing.NestedTensorDict, outputs: typing.NestedTensorDict
  ) -> Tuple[tf.Tensor, typing.NestedTensorDict]:
    """Computes the loss.

    Args:
      labels: A dictionary of ground-truth labels.
      outputs: Output from self.call().

    Returns:
      A scalar total loss tensor and a dictionary for individual losses.
    """
    loss_dict = {}
    self._preprocess_labels(labels)
    # Main loss: PQ loss.
    _entity_mask_loss(loss_dict, labels["instance_labels"],
                      outputs["instance_output"])
    # Auxiliary loss 1: semantic loss
    _semantic_loss(loss_dict, labels["segmentation_output"],
                   outputs["segmentation_output"])
    # Auxiliary loss 2: instance discrimination
    _instance_discrimination_loss(loss_dict, labels["instance_labels"],
                                  outputs)
    # Auxiliary loss 3: mask id
    _mask_id_xent_loss(loss_dict, labels["instance_labels"], outputs)
    # Auxiliary loss 4: paragraph grouping
    _paragraph_grouping_loss(loss_dict, labels, outputs)
    weighted_loss = [self._loss_weights[k] * v for k, v in loss_dict.items()]
    total_loss = sum(weighted_loss)
    return total_loss, loss_dict

  def call(self,
           features: typing.TensorDict,
           training: bool = False) -> typing.NestedTensorDict:
    """Forward pass of the model.

    Args:
      features: The input features: {"images": tf.Tensor}. Shape = [B, H, W,
        C].
      training: Whether it's training mode.

    Returns:
      A dictionary of outputs with this structure:
        {
          "max_deep_lab": {
            All the MaX-DeepLab outputs are here, including both backbone
            and decoder.
          }
          "segmentation_output": {
            "word_score": tf.Tensor, [B, h, w],
          }
          "instance_output": {
            "cls_logits": tf.Tensor, [B, N, C],
            "mask_id_logits": tf.Tensor, [B, H, W, N],
            "cls_prob": tf.Tensor, [B, N, C],
            "mask_id_prob": tf.Tensor, [B, H, W, N],
          }
          "postprocessed": {
            "classes": A (B, N) tensor for the class ids. Zero for
              non-firing slots.
            "binary_masks": A (B, H, W, N) tensor for the N binary masks.
              Masks for the void cls are set to zero.
            "confidence": A (B, N) float tensor for the confidence of
              "classes".
            "mask_area": A (B, N) float tensor for the area of each mask.
          }
          "transformer_group_feature": (B, N, C) float tensor (normalized),
          "para_affinity": (B, N, N) float tensor.
        }
      Class-0 is for void. Class-(C-1) is for background. Class-1~(C-2) are
      the valid classes.
    """
    # backbone
    backbone_output = self._backbone_fn(features["images"], training)
    # split instance embedding and paragraph embedding;
    # then perform paragraph grouping
    para_fts = self._get_para_outputs(backbone_output, training)
    affinity = tf.linalg.matmul(para_fts, para_fts, transpose_b=True)
    # text detection head
    decoder_output = self._decoder(backbone_output, training)
    output_dict = {
        "max_deep_lab": decoder_output,
        "transformer_group_feature": para_fts,
        "para_affinity": affinity,
    }
    input_shape = utilities.resolve_shape(features["images"])
    self._get_semantic_outputs(output_dict, input_shape)
    self._get_instance_outputs(output_dict, input_shape)
    self._postprocess(output_dict)
    return output_dict

  def _get_para_outputs(self, outputs: typing.TensorDict,
                        training: bool) -> tf.Tensor:
    """Apply the paragraph head.

    This function first splits the features for instance classification and
    instance grouping. Then, the additional grouping branch (transformer
    layers) is applied to further encode the grouping features. Finally, a
    tensor of normalized grouping features is returned.

    Args:
      outputs: output dictionary from the backbone.
      training: training / eval mode mark.

    Returns:
      The normalized paragraph embedding vector of shape (B, N, C).
    """
    # Project the object embeddings into classification features and
    # grouping features.
    fts = outputs["transformer_class_feature"]  # B,N,C
    class_feature = self._class_embed_head(fts, training)
    group_feature = self._para_embed_head(fts, training)
    outputs["transformer_class_feature"] = class_feature
    outputs["transformer_group_feature"] = group_feature

    # Feed the grouping features into the additional group encoding branch.
    # First we need to build the attention_bias which is used by the
    # standard transformer encoder.
    input_shape = utilities.resolve_shape(group_feature)
    b = input_shape[0]
    n = int(input_shape[1])
    seq_len = tf.constant(n, shape=(b,))
    padding_mask = utilities.get_padding_mask_from_valid_lengths(
        seq_len, n, tf.float32)
    attention_bias = utilities.get_transformer_attention_bias(padding_mask)
    group_feature = self._para_proj(
        self._para_head(group_feature, attention_bias, None, training))
    return tf.math.l2_normalize(group_feature, axis=-1)

  def _get_semantic_outputs(self, outputs: typing.NestedTensorDict,
                            input_shape: tf.TensorShape):
    """Add `segmentation_output` to outputs.

    Args:
      outputs: A dictionary of outputs.
      input_shape: The shape of the input images.
    """
    h, w = input_shape[1:3]
    # B, H/4, W/4, C
    semantic_logits = outputs["max_deep_lab"]["semantic_logits"]
    textness, unused_logits = tf.split(semantic_logits, [2, -1], -1)
    # Channel[0:2], textness. c0: non-textness, c1: textness.
    word_score = tf.nn.softmax(textness, -1, "word_score")[:, :, :, 1:2]
    word_score = tf.squeeze(tf.image.resize(word_score, (h, w)), -1)
    # Channel[2:] not used yet
    outputs["segmentation_output"] = {"word_score": word_score}

  def _get_instance_outputs(self, outputs: typing.NestedTensorDict,
                            input_shape: tf.TensorShape):
    """Add `instance_output` to outputs.

    Args:
      outputs: A dictionary of outputs.
      input_shape: The shape of the input images.

    The following fields are added to outputs["instance_output"]:
      "cls_logits": tf.Tensor, [B, N, C].
      "mask_id_logits": tf.Tensor, [B, H, W, N].
      "cls_prob": tf.Tensor, [B, N, C], softmax probability.
      "mask_id_prob": tf.Tensor, [B, H, W, N], softmax probability.
    They are used in training. Masks are all resized to full resolution.
    """
    # Get instance_output
    h, w = input_shape[1:3]
    ## Classes
    class_logits = outputs["max_deep_lab"]["transformer_class_logits"]
    # The MaX-DeepLab repo uses the last logit for void, but we use 0.
    # Therefore we shift the logits here.
    class_logits = tf.roll(class_logits, shift=1, axis=-1)
    class_prob = tf.nn.softmax(class_logits)
    ## Masks
    mask_id_logits = outputs["max_deep_lab"]["pixel_space_mask_logits"]
    mask_id_prob = tf.nn.softmax(mask_id_logits)
    mask_id_logits = tf.image.resize(mask_id_logits, (h, w))
    mask_id_prob = tf.image.resize(mask_id_prob, (h, w))
    outputs["instance_output"] = {
        "cls_logits": class_logits,
        "mask_id_logits": mask_id_logits,
        "cls_prob": class_prob,
        "mask_id_prob": mask_id_prob,
    }

  def _postprocess(self, outputs: typing.NestedTensorDict):
    """Post-process (filter) the outputs.

    Args:
      outputs: A dictionary of outputs.

    The following fields are added to outputs["postprocessed"]:
      "classes": A (B, N) integer tensor for the class ids.
      "binary_masks": A (B, H, W, N) tensor for the N binarized 0/1 masks.
        Masks for the void cls are set to zero.
      "confidence": A (B, N) float tensor for the confidence of "classes".
      "mask_area": A (B, N) float tensor for the area of each mask.
    They are used in inference / visualization.
    """
    # Get postprocessed outputs
    outputs["postprocessed"] = {}
    ## Masks:
    mask_id_prob = outputs["instance_output"]["mask_id_prob"]
    mask_max_prob = tf.reduce_max(mask_id_prob, axis=-1, keepdims=True)
    thresholded_binary_masks = tf.cast(
        tf.math.logical_and(
            tf.equal(mask_max_prob, mask_id_prob),
            tf.greater_equal(mask_max_prob, self._mask_threshold)),
        tf.float32)
    area = tf.reduce_sum(thresholded_binary_masks, axis=(1, 2))  # (B, N)
    ## Classification:
    cls_prob = outputs["instance_output"]["cls_prob"]
    cls_max_prob = tf.reduce_max(cls_prob, axis=-1)  # B, N
    cls_max_id = tf.cast(tf.argmax(cls_prob, axis=-1), tf.float32)  # B, N
    ## Filtering
    c = utilities.resolve_shape(cls_prob)[2]
    non_void = tf.reduce_all(
        tf.stack(
            [
                # mask large enough.
                tf.greater_equal(area, self._filter_area),
                # class-0 is for non-object.
                tf.not_equal(cls_max_id, 0),
                # class-(c-1) is for background (last).
                tf.not_equal(cls_max_id, c - 1),
                # prob >= thr
                tf.greater_equal(cls_max_prob, self._class_threshold)
            ],
            axis=-1),
        axis=-1)
    non_void = tf.cast(non_void, tf.float32)
    # Storing
    outputs["postprocessed"]["classes"] = tf.cast(cls_max_id * non_void,
                                                  tf.int32)
    b, n = utilities.resolve_shape(non_void)
    outputs["postprocessed"]["binary_masks"] = (
        thresholded_binary_masks * tf.reshape(non_void, (b, 1, 1, n)))
    outputs["postprocessed"]["confidence"] = cls_max_prob
    outputs["postprocessed"]["mask_area"] = area

  def _coloring(self, masks: tf.Tensor) -> tf.Tensor:
    """Coloring segmentation masks.

    Used in visualization.

    Args:
      masks: A float binary tensor of shape (B, H, W, N), representing `B`
        samples, with `N` masks of size `H*W` each. Each of the `N` masks
        will be assigned a random color.

    Returns:
      A (b, h, w, 3) float tensor in [0., 1.] for the coloring result.
    """
    b, h, w, n = utilities.resolve_shape(masks)
    palette = tf.random.uniform((1, n, 3), 0.5, 1.)
    colored = tf.reshape(
        tf.matmul(tf.reshape(masks, (b, -1, n)), palette), (b, h, w, 3))
    return colored

  def visualize(self,
                outputs: typing.NestedTensorDict,
                labels: Optional[typing.TensorDict] = None):
    """Visualizes the outputs and labels.

    Args:
      outputs: A dictionary of outputs.
      labels: A dictionary of labels.

    The following dict is added to outputs["visualization"]:
      {
        "instance": {
          "pred": A (B, H, W, 3) tensor for the visualized map in [0,1].
          "gt": A (B, H, W, 3) tensor for the visualized map in [0,1], if
            labels is present.
          "concat": Concatenation of "pred" and "gt" along the width axis,
            if labels is present.
        }
        "seg-text": {... Similar to above, but the shape is (B, H, W, 1).}
      }
    All of these tensors have a rank of 4 (B, H, W, C).
    """
    outputs["visualization"] = {}
    # 1. prediction
    # 1.1 instance mask
    binary_masks = outputs["postprocessed"]["binary_masks"]
    outputs["visualization"]["instance"] = {
        "pred": self._coloring(binary_masks),
    }
    # 1.2 text-seg
    outputs["visualization"]["seg-text"] = {
        "pred":
            tf.expand_dims(outputs["segmentation_output"]["word_score"], -1),
    }
    # 2. labels
    if labels is not None:
      # 2.1 instance mask
      # (B, N, H, W) -> (B, H, W, N); the first one is bkg so removed.
      gt_masks = tf.transpose(labels["instance_labels"]["masks"][:, 1:],
                              (0, 2, 3, 1))
      outputs["visualization"]["instance"]["gt"] = self._coloring(gt_masks)
      # 2.2 text-seg
      outputs["visualization"]["seg-text"]["gt"] = tf.expand_dims(
          labels["segmentation_output"]["gt_word_score"], -1)
      # 3. concat
      for v in outputs["visualization"].values():
        # Resize to make the sizes align. The prediction always has stride=1
        # resolution, so we make gt align with pred instead of vice versa.
        v["concat"] = tf.concat(
            [v["pred"],
             tf.image.resize(v["gt"], tf.shape(v["pred"])[1:3])],
            axis=2)

  @tf.function
  def serve(self, image_tensor: tf.Tensor) -> typing.NestedTensorDict:
    """Method to be exported for SavedModel.

    Args:
      image_tensor: A float32 normalized tensor representing an image of
        shape [1, height, width, channels].

    Returns:
      Dict of outputs:
        classes: (B, N) int32 tensor == o["postprocessed"]["classes"]
        masks: (B, H, W, N) float32 tensor ==
          o["postprocessed"]["binary_masks"]
        groups: (B, N, N) float32 tensor == o["para_affinity"]
        confidence: A (B, N) float tensor == o["postprocessed"]["confidence"]
        mask_area: A (B, N) float tensor == o["postprocessed"]["mask_area"]
    """
    features = {"images": image_tensor}
    nn_outputs = self(features, False)
    outputs = {
        "classes": nn_outputs["postprocessed"]["classes"],
        "masks": nn_outputs["postprocessed"]["binary_masks"],
        "confidence": nn_outputs["postprocessed"]["confidence"],
        "mask_area": nn_outputs["postprocessed"]["mask_area"],
        "groups": nn_outputs["para_affinity"],
    }
    return outputs


@gin.configurable()
def _get_decoder_head(
    atrous_rates: Sequence[int] = (6, 12, 18),
    pixel_space_dim: int = 128,
    pixel_space_intermediate: int = 256,
    low_level: Sequence[Dict[str, Union[str, int]]] = (
        {
            "feature_key": "res3",
            "channels_project": 64,
        },
        {
            "feature_key": "res2",
            "channels_project": 32,
        },
    ),
    num_classes=3,
    aux_sem_intermediate=256,
    norm_fn=tf.keras.layers.BatchNormalization,
) -> max_deeplab_head.MaXDeepLab:
  """Get the MaX-DeepLab prediction head.

  Args:
    atrous_rates: Dilation rates for atrous conv in the semantic head.
    pixel_space_dim: The dimension of the final panoptic features.
    pixel_space_intermediate: The dimension of the layer before
      `pixel_space_dim` (i.e. the separable 5x5 layer).
    low_level: A list of dicts for the feature pyramid used in forming the
      semantic output. Each dict represents one skip-path from the backbone.
    num_classes: Number of classes (entities + bkg) including void. For
      example, if we only want to detect words, then `num_classes` = 3 (1
      for word, 1 for bkg, and 1 for void).
    aux_sem_intermediate: Similar to `pixel_space_intermediate`, but for the
      auxiliary semantic output head.
    norm_fn: The normalization function used in the head.

  Returns:
    A MaX-DeepLab decoder head (as a Keras layer).
  """
  # Initialize the configs.
  configs = config_pb2.ModelOptions()
  configs.decoder.feature_key = "feature_semantic"
  configs.decoder.atrous_rates.extend(atrous_rates)
  configs.max_deeplab.pixel_space_head.output_channels = pixel_space_dim
  configs.max_deeplab.pixel_space_head.head_channels = (
      pixel_space_intermediate)
  for low_level_config in low_level:
    low_level_ = configs.max_deeplab.auxiliary_low_level.add()
    low_level_.feature_key = low_level_config["feature_key"]
    low_level_.channels_project = low_level_config["channels_project"]
  configs.max_deeplab.auxiliary_semantic_head.output_channels = num_classes
  configs.max_deeplab.auxiliary_semantic_head.head_channels = (
      aux_sem_intermediate)
  return max_deeplab_head.MaXDeepLab(configs.decoder, configs.max_deeplab, 0,
                                     norm_fn)
class PseudoLayer(tf.keras.layers.Layer):
  """Pseudo layer for ablation study.

  The `call()` function has the same argument signature as a transformer
  encoder stack. `unused_ph1` and `unused_ph2` are placeholders for this
  purpose. When studying the effectiveness of using a transformer as the
  grouping branch, we can use this PseudoLayer to replace the transformer
  as a no-transformer baseline.

  To use a single projection layer instead of a transformer, simply set
  `extra_fc` to True.
  """

  def __init__(self, extra_fc: bool):
    super().__init__(name="extra_fc")
    self._extra_fc = extra_fc
    if extra_fc:
      self._layer = tf.keras.Sequential([
          tf.keras.layers.Dense(256, activation="relu"),
          tf.keras.layers.LayerNormalization(),
      ])

  def call(self,
           fts: tf.Tensor,
           unused_ph1: Optional[tf.Tensor],
           unused_ph2: Optional[tf.Tensor],
           training: Optional[bool] = None) -> tf.Tensor:
    """See base class."""
    if self._extra_fc:
      return self._layer(fts, training)
    return fts
@gin.configurable()
def _get_embed_head(
    dimension=256,
    norm_fn=tf.keras.layers.BatchNormalization
) -> Tuple[tf.keras.Sequential, tf.keras.Sequential]:
  """Projection layers to get instance & grouping features."""
  instance_head = tf.keras.Sequential([
      tf.keras.layers.Dense(dimension, use_bias=False),
      norm_fn(),
      tf.keras.layers.ReLU(),
  ])
  grouping_head = tf.keras.Sequential([
      tf.keras.layers.Dense(dimension, use_bias=False),
      norm_fn(),
      tf.keras.layers.ReLU(),
  ])
  return instance_head, grouping_head
@gin.configurable()
def _get_para_head(
    dimension=128,
    num_layer=3,
    extra_fc=False
) -> Tuple[tf.keras.layers.Layer, tf.keras.layers.Layer]:
  """Get the additional para head.

  Args:
    dimension: The dimension of the final output.
    num_layer: The number of transformer layers.
    extra_fc: Whether an extra single fully-connected layer is used when
      num_layer=0.

  Returns:
    An encoder and a projection layer for the grouping features.
  """
  if num_layer > 0:
    encoder = transformer.EncoderStack(
        params={
            "hidden_size": 256,
            "num_hidden_layers": num_layer,
            "num_heads": 4,
            "filter_size": 512,
            "initializer_gain": 1.0,
            "attention_dropout": 0.1,
            "relu_dropout": 0.1,
            "layer_postprocess_dropout": 0.1,
            "allow_ffn_pad": True,
        })
  else:
    encoder = PseudoLayer(extra_fc)
  dense = tf.keras.layers.Dense(dimension)
  return encoder, dense
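
# NOTE: An illustrative sketch (not part of the original file). The ablations
# described in `PseudoLayer` and `_get_para_head` can be selected purely
# through Gin bindings, e.g.:
#
#   import gin
#   gin.parse_config("""
#   _get_para_head.num_layer = 0    # Drop the transformer encoder...
#   _get_para_head.extra_fc = True  # ...and use a single FC layer instead.
#   """)
#
# With `num_layer = 0` and `extra_fc = False`, the grouping branch reduces to
# an identity mapping (the plain `PseudoLayer` baseline).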
def _dice_sim(pred: tf.Tensor, ground_truth: tf.Tensor) -> tf.Tensor:
  """Dice coefficient for mask similarity.

  Args:
    pred: The predicted mask. [B, N, H, W], in [0, 1].
    ground_truth: The ground-truth mask. [B, N, H, W], in [0, 1] or {0, 1}.

  Returns:
    A matrix for the losses: m[b, i, j] is the dice similarity between pred `i`
    and gt `j` in batch `b`.
  """
  b, n = utilities.resolve_shape(pred)[:2]
  ground_truth = tf.reshape(
      tf.transpose(ground_truth, (0, 2, 3, 1)), (b, -1, n))  # B, HW, N
  pred = tf.reshape(pred, (b, n, -1))  # B, N, HW
  numerator = tf.matmul(pred, ground_truth) * 2.
  # TODO(longshangbang): The official implementation does not square the scores.
  # Need to do experiment to determine which one is better.
  denominator = (
      tf.math.reduce_sum(tf.math.square(ground_truth), 1, keepdims=True) +
      tf.math.reduce_sum(tf.math.square(pred), 2, keepdims=True))
  return (numerator + EPSILON) / (denominator + EPSILON)
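
# NOTE: A small self-check sketch (not part of the original file) showing the
# expected behavior of `_dice_sim` on toy inputs: identical masks score ~1 and
# disjoint masks score ~0 (up to EPSILON smoothing).
#
#   mask_a = tf.constant([[[[1., 1.], [0., 0.]]]])  # [B=1, N=1, H=2, W=2]
#   mask_b = tf.constant([[[[0., 0.], [1., 1.]]]])
#   _dice_sim(mask_a, mask_a)  # ~= [[[1.]]]
#   _dice_sim(mask_a, mask_b)  # ~= [[[0.]]]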
def _semantic_loss(
    loss_dict: Dict[str, tf.Tensor],
    labels: tf.Tensor,
    outputs: tf.Tensor,
):
  """Auxiliary semantic loss.

  Currently, these losses are added:
  (1) text/non-text heatmap

  Args:
    loss_dict: A dictionary for the loss. The values are loss scalars.
    labels: The label dictionary containing:
      `gt_word_score`: (B, H, W) tensor for the text/non-text map.
    outputs: The output dictionary containing:
      `word_score`: (B, H, W) prediction tensor for `gt_word_score`.
  """
  pred = tf.expand_dims(outputs["word_score"], 1)
  gt = tf.expand_dims(labels["gt_word_score"], 1)
  loss_dict["loss_segmentation_word"] = 1. - tf.reduce_mean(_dice_sim(pred, gt))
@gin.configurable
def _entity_mask_loss(loss_dict: Dict[str, tf.Tensor],
                      labels: tf.Tensor,
                      outputs: tf.Tensor,
                      alpha: float = gin.REQUIRED):
  """PQ loss for entity-mask training.

  This method adds the PQ loss term to loss_dict directly. The match result
  will also be stored in outputs (as a [B, N_pred, N_gt] float tensor).

  Args:
    loss_dict: A dictionary for the loss. The values are loss scalars.
    labels: A dict containing:
      `num_instance`: (B,)
      `masks`: (B, N, H, W)
      `classes`: (B, N)
    outputs: A dict containing:
      `cls_prob`: (B, N, C)
      `mask_id_prob`: (B, H, W, N)
      `cls_logits`: (B, N, C)
      `mask_id_logits`: (B, H, W, N)
    alpha: Weight for pos/neg balance.
  """
  # Classification score: (B, N, N)
  # In batch b, the probability of prediction i being the class of gt j, i.e.:
  # score[b, i, j] = pred_cls[b, i, gt_cls[b, j]]
  gt_cls = labels["classes"]  # (B, N)
  pred_cls = outputs["cls_prob"]  # (B, N, C)
  b, n = utilities.resolve_shape(pred_cls)[:2]
  # indices[b, i, j] = gt_cls[b, j]
  indices = tf.tile(tf.expand_dims(gt_cls, 1), (1, n, 1))
  cls_score = tf.gather(pred_cls, tf.cast(indices, tf.int32), batch_dims=2)
  # Mask score (dice): (B, N, N)
  # mask_score[b, i, j]: dice-similarity for pred i and gt j in batch b.
  mask_score = _dice_sim(
      tf.transpose(outputs["mask_id_prob"], (0, 3, 1, 2)), labels["masks"])
  # Get similarity matrix and matching.
  # padded_mask[b, j, i] = -1 << other scores, if i >= num_instance[b]
  similarity = cls_score * mask_score
  padded_mask = tf.cast(tf.reshape(tf.range(n), (1, 1, n)), tf.float32)
  padded_mask = tf.cast(
      tf.math.greater_equal(padded_mask,
                            tf.reshape(labels["num_instance"], (b, 1, 1))),
      tf.float32)
  # The constant value for padding has no effect.
  masked_similarity = similarity * (1. - padded_mask) + padded_mask * (-1.)
  matched_mask = matchers_ops.hungarian_matching(-masked_similarity)
  matched_mask = tf.cast(matched_mask, tf.float32) * (1 - padded_mask)
  outputs["matched_mask"] = matched_mask
  # Pos loss
  loss_pos = (
      tf.stop_gradient(cls_score) * (-mask_score) +
      tf.stop_gradient(mask_score) * (-tf.math.log(cls_score)))
  loss_pos = tf.reduce_sum(loss_pos * matched_mask, axis=[1, 2])  # (B,)
  # Neg loss
  matched_pred = tf.cast(
      tf.reduce_sum(matched_mask, axis=2) > 0, tf.float32)  # (B, N)
  # 0 for void class
  log_loss = -tf.nn.log_softmax(outputs["cls_logits"])[:, :, 0]  # (B, N)
  loss_neg = tf.reduce_sum(log_loss * (1. - matched_pred), axis=-1)  # (B,)
  loss_pq = (alpha * loss_pos + (1 - alpha) * loss_neg) / n
  loss_pq = tf.reduce_mean(loss_pq)
  loss_dict["loss_pq"] = loss_pq
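
# NOTE: An illustrative sketch (not part of the original file). The bipartite
# matching above maximizes the total `cls_score * mask_score` similarity. On
# CPU, the same assignment can be reproduced with SciPy:
#
#   import numpy as np
#   from scipy.optimize import linear_sum_assignment
#
#   similarity = np.array([[0.9, 0.1], [0.2, 0.8]])
#   rows, cols = linear_sum_assignment(-similarity)  # Minimize the negation.
#   # rows == [0, 1], cols == [0, 1]: pred 0 -> gt 0, pred 1 -> gt 1.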
@gin.configurable
def _instance_discrimination_loss(loss_dict: Dict[str, Any],
                                  labels: Dict[str, Any],
                                  outputs: Dict[str, Any],
                                  tau: float = gin.REQUIRED):
  """Instance discrimination loss.

  This method adds the ID loss term to loss_dict directly.

  Args:
    loss_dict: A dictionary for the loss. The values are loss scalars.
    labels: The label dictionary.
    outputs: The output dictionary.
    tau: The temperature term in the loss.
  """
  # The normalized feature, shape=(B, H/4, W/4, D)
  g = outputs["max_deep_lab"]["pixel_space_normalized_feature"]
  b, h, w = utilities.resolve_shape(g)[:3]
  # The ground-truth masks, shape=(B, N, H, W) --> (B, N, H/4, W/4)
  m = labels["masks"]
  m = tf.image.resize(
      tf.transpose(m, (0, 2, 3, 1)), (h, w),
      tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  m = tf.transpose(m, (0, 3, 1, 2))
  # The number of ground-truth instances (K), shape=(B,)
  num = labels["num_instance"]
  n = utilities.resolve_shape(m)[1]  # max number of predictions
  # is_void[b, i] = 1 if instance i in batch b is a padded slot.
  is_void = tf.cast(tf.expand_dims(tf.range(n), 0), tf.float32)  # (1, n)
  is_void = tf.cast(
      tf.math.greater_equal(is_void, tf.expand_dims(num, 1)), tf.float32)
  # (B, N, D)
  t = tf.math.l2_normalize(tf.einsum("bhwd,bnhw->bnd", g, m), axis=-1)
  inst_dist_logits = tf.einsum("bhwd,bid->bhwi", g, t) / tau  # (B, H, W, N)
  inst_dist_logits = inst_dist_logits - 100. * tf.reshape(is_void, (b, 1, 1, n))
  mask_id = tf.cast(
      tf.einsum("bnhw,n->bhw", m, tf.range(n, dtype=tf.float32)), tf.int32)
  loss_map = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=mask_id, logits=inst_dist_logits)  # B, H, W
  valid_mask = tf.reduce_sum(m, axis=1)
  loss_inst_dist = (
      (tf.reduce_sum(loss_map * valid_mask, axis=[1, 2]) + EPSILON) /
      (tf.reduce_sum(valid_mask, axis=[1, 2]) + EPSILON))
  loss_dict["loss_inst_dist"] = tf.reduce_mean(loss_inst_dist)
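
# NOTE: A small sketch (not part of the original file) of the einsum used
# above: "bhwd,bnhw->bnd" sums the pixel features inside each ground-truth
# mask (masked pooling, up to the subsequent l2 normalization). Equivalently:
#
#   # t_unnormalized[b, n, d] = sum_{h, w} g[b, h, w, d] * m[b, n, h, w]
#   t_unnormalized = tf.reduce_sum(
#       g[:, tf.newaxis] * m[..., tf.newaxis], axis=[2, 3])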
@gin.configurable
def _paragraph_grouping_loss(
    loss_dict: Dict[str, Any],
    labels: Dict[str, Any],
    outputs: Dict[str, Any],
    tau: float = gin.REQUIRED,
    loss_mode="vanilla",
    fl_alpha: float = 0.25,
    fl_gamma: float = 2.,
):
  """Paragraph grouping loss.

  This method adds the paragraph discrimination loss term to loss_dict
  directly.

  Args:
    loss_dict: A dictionary for the loss. The values are loss scalars.
    labels: The label dictionary.
    outputs: The output dictionary.
    tau: The temperature term in the loss.
    loss_mode: The type of loss.
    fl_alpha: The alpha value in the focal loss.
    fl_gamma: The gamma value in the focal loss.
  """
  if "paragraph_labels" not in labels:
    loss_dict["loss_para"] = 0.
    return

  # Step 1:
  # Obtain the paragraph labels for each prediction.
  # (batch, pred, gt)
  matched_matrix = outputs["instance_output"]["matched_mask"]  # B, N, N
  para_label_gt = labels["paragraph_labels"]["paragraph_ids"]  # B, N
  has_para_label_gt = (
      labels["paragraph_labels"]["has_para_ids"][:, tf.newaxis, tf.newaxis])
  # '0' means no paragraph labels.
  pred_label_gt = tf.einsum("bij,bj->bi", matched_matrix,
                            tf.cast(para_label_gt + 1, tf.float32))
  pred_label_gt_pad_col = tf.expand_dims(pred_label_gt, -1)  # b,n,1
  pred_label_gt_pad_row = tf.expand_dims(pred_label_gt, 1)  # b,1,n
  gt_affinity = tf.cast(
      tf.equal(pred_label_gt_pad_col, pred_label_gt_pad_row), tf.float32)
  gt_affinity_mask = (
      has_para_label_gt * pred_label_gt_pad_col * pred_label_gt_pad_row)
  gt_affinity_mask = tf.cast(tf.not_equal(gt_affinity_mask, 0.), tf.float32)

  # Step 2:
  # Get the affinity matrix.
  affinity = outputs["para_affinity"]

  # Step 3:
  # Compute the loss.
  loss_fn = tf.keras.losses.BinaryCrossentropy(
      from_logits=True,
      label_smoothing=0,
      axis=-1,
      reduction=tf.keras.losses.Reduction.NONE,
      name="para_dist")
  affinity = tf.reshape(affinity, (-1, 1))  # (b*n*n, 1)
  gt_affinity = tf.reshape(gt_affinity, (-1, 1))  # (b*n*n, 1)
  gt_affinity_mask = tf.reshape(gt_affinity_mask, (-1,))  # (b*n*n,)
  pointwise_loss = loss_fn(gt_affinity, affinity / tau)  # (b*n*n,)
  if loss_mode == "vanilla":
    loss = (
        tf.reduce_sum(pointwise_loss * gt_affinity_mask) /
        (tf.reduce_sum(gt_affinity_mask) + EPSILON))
  elif loss_mode == "balanced":
    # pos
    pos_mask = gt_affinity_mask * gt_affinity[:, 0]
    pos_loss = (
        tf.reduce_sum(pointwise_loss * pos_mask) /
        (tf.reduce_sum(pos_mask) + EPSILON))
    # neg
    neg_mask = gt_affinity_mask * (1. - gt_affinity[:, 0])
    neg_loss = (
        tf.reduce_sum(pointwise_loss * neg_mask) /
        (tf.reduce_sum(neg_mask) + EPSILON))
    loss = 0.25 * pos_loss + 0.75 * neg_loss
  elif loss_mode == "focal":
    alpha_wt = fl_alpha * gt_affinity + (1. - fl_alpha) * (1. - gt_affinity)
    prob_pos = tf.math.sigmoid(affinity / tau)
    pt = prob_pos * gt_affinity + (1. - prob_pos) * (1. - gt_affinity)
    fl_loss_pw = tf.stop_gradient(
        alpha_wt * tf.pow(1. - pt, fl_gamma))[:, 0] * pointwise_loss
    loss = (
        tf.reduce_sum(fl_loss_pw * gt_affinity_mask) /
        (tf.reduce_sum(gt_affinity_mask) + EPSILON))
  else:
    raise ValueError(f"Not supported loss mode: {loss_mode}")
  loss_dict["loss_para"] = loss
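
# NOTE: A toy numeric check (not part of the original file) of the focal
# weighting above. For a positive pair (gt_affinity = 1) predicted with
# p = sigmoid(affinity / tau) = 0.9:
#   pt = 0.9, alpha_wt = fl_alpha = 0.25
#   focal weight = 0.25 * (1 - 0.9)**2 = 0.0025
# A hard positive with p = 0.1 instead gets 0.25 * 0.9**2 ≈ 0.2, i.e. 81x
# larger, so easy pairs contribute almost nothing to the loss.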
def _mask_id_xent_loss(loss_dict: Dict[str, Any], labels: Dict[str, Any],
                       outputs: Dict[str, Any]):
  """Mask ID loss.

  This method adds the mask ID loss term to loss_dict directly.

  Args:
    loss_dict: A dictionary for the loss. The values are loss scalars.
    labels: The label dictionary.
    outputs: The output dictionary.
  """
  mask_gt = labels["masks"]  # (B, N, H, W)
  mask_id_logits = outputs["instance_output"]["mask_id_logits"]  # B, H, W, N
  matched_matrix = outputs["instance_output"]["matched_mask"]  # B, N, N
  gt_to_pred_id = tf.cast(
      tf.math.argmax(matched_matrix, axis=1), tf.float32)  # B, N
  mask_id_labels = tf.cast(
      tf.einsum("bnhw,bn->bhw", mask_gt, gt_to_pred_id), tf.int32)  # B, H, W
  loss_map = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=mask_id_labels, logits=mask_id_logits)
  valid_mask = tf.reduce_sum(mask_gt, axis=1)
  loss_mask_id = (
      (tf.reduce_sum(loss_map * valid_mask, axis=[1, 2]) + EPSILON) /
      (tf.reduce_sum(valid_mask, axis=[1, 2]) + EPSILON))
  loss_dict["loss_mask_id"] = tf.reduce_mean(loss_mask_id)
official/projects/unified_detector/registry_imports.py
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""All necessary imports for registration."""
# pylint: disable=unused-import
from official.projects.unified_detector import external_configurables
from official.projects.unified_detector.configs import ocr_config
from official.projects.unified_detector.tasks import ocr_task
from official.vision import registry_imports
official/projects/unified_detector/requirements.txt
tf-nightly
gin-config
opencv-python==4.1.2.30
absl-py>=1.0.0
shapely>=1.8.1
apache_beam>=2.37.0
matplotlib>=3.5.1
notebook>=6.4.10
official/projects/unified_detector/run_inference.py
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""A binary to run unified detector."""

import json
import os
from typing import Any, Dict, List, Sequence, Tuple

from absl import app
from absl import flags
from absl import logging
import cv2
import gin
import numpy as np
import tensorflow as tf
import tqdm

from official.projects.unified_detector import external_configurables  # pylint: disable=unused-import
from official.projects.unified_detector.modeling import universal_detector
from official.projects.unified_detector.utils import utilities


# Group two lines into a paragraph if the affinity score is higher than this.
_PARA_GROUP_THR = 0.5

# MODEL spec
_GIN_FILE = flags.DEFINE_string(
    'gin_file', None, 'Path to the Gin file that defines the model.')
_CKPT_PATH = flags.DEFINE_string(
    'ckpt_path', None, 'Path to the checkpoint directory.')
_IMG_SIZE = flags.DEFINE_integer(
    'img_size', 1024, 'Size of the image fed to the model.')

# Input & Output
# Note that all images specified by `img_file` and `img_dir` will be processed.
_IMG_FILE = flags.DEFINE_multi_string('img_file', [], 'Paths to the images.')
_IMG_DIR = flags.DEFINE_multi_string(
    'img_dir', [], 'Paths to the image directories.')
_OUTPUT_PATH = flags.DEFINE_string('output_path', None, 'Path for the output.')
_VIS_DIR = flags.DEFINE_string(
    'vis_dir', None, 'Path for the visualization output.')


def _preprocess(raw_image: np.ndarray) -> Tuple[np.ndarray, float]:
  """Converts a raw image to a properly resized, padded, normalized ndarray."""
  # (1) Convert to tf.Tensor and float32.
  img_tensor = tf.convert_to_tensor(raw_image, dtype=tf.float32)

  # (2) Pad to square.
  height, width = img_tensor.shape[:2]
  maximum_side = tf.maximum(height, width)
  height_pad = maximum_side - height
  width_pad = maximum_side - width
  img_tensor = tf.pad(
      img_tensor, [[0, height_pad], [0, width_pad], [0, 0]],
      constant_values=127)
  ratio = maximum_side / _IMG_SIZE.value

  # (3) Resize the long side to the maximum length.
  img_tensor = tf.image.resize(img_tensor, (_IMG_SIZE.value, _IMG_SIZE.value))
  img_tensor = tf.cast(img_tensor, tf.uint8)

  # (4) Normalize.
  img_tensor = utilities.normalize_image_to_range(img_tensor)

  # (5) Add the batch dimension and return as a numpy array.
  return tf.expand_dims(img_tensor, 0).numpy(), float(ratio)


def load_model() -> tf.keras.layers.Layer:
  gin.parse_config_file(_GIN_FILE.value)
  model = universal_detector.UniversalDetector()
  ckpt = tf.train.Checkpoint(model=model)
  ckpt_path = _CKPT_PATH.value
  logging.info('Load ckpt from: %s', ckpt_path)
  ckpt.restore(ckpt_path).expect_partial()
  return model


def inference(img_file: str,
              model: tf.keras.layers.Layer) -> List[Dict[str, Any]]:
  """Inference step."""
  img = cv2.cvtColor(cv2.imread(img_file), cv2.COLOR_BGR2RGB)
  img_ndarray, ratio = _preprocess(img)

  output_dict = model.serve(img_ndarray)
  class_tensor = output_dict['classes'].numpy()
  mask_tensor = output_dict['masks'].numpy()
  group_tensor = output_dict['groups'].numpy()

  indices = np.where(class_tensor[0])[0].tolist()  # Indices of positive slots.
  mask_list = [mask_tensor[0, :, :, index]
               for index in indices]  # List of mask ndarrays.

  # Form lines and words.
  lines = []
  line_indices = []
  for index, mask in tqdm.tqdm(zip(indices, mask_list)):
    line = {
        'words': [],
        'text': '',
    }

    contours, _ = cv2.findContours((mask > 0.).astype(np.uint8), cv2.RETR_TREE,
                                   cv2.CHAIN_APPROX_SIMPLE)[-2:]
    for contour in contours:
      if (isinstance(contour, np.ndarray) and len(contour.shape) == 3 and
          contour.shape[0] > 2 and contour.shape[1] == 1 and
          contour.shape[2] == 2):
        cnt_list = (contour[:, 0] * ratio).astype(np.int32).tolist()
        line['words'].append({'text': '', 'vertices': cnt_list})
      else:
        logging.error('Invalid contour: %s, discarded', str(contour))
    if line['words']:
      lines.append(line)
      line_indices.append(index)

  # Form paragraphs.
  line_grouping = utilities.DisjointSet(len(line_indices))
  affinity = group_tensor[0][line_indices][:, line_indices]
  for i1, i2 in zip(*np.where(affinity > _PARA_GROUP_THR)):
    line_grouping.union(i1, i2)

  line_groups = line_grouping.to_group()
  paragraphs = []
  for line_group in line_groups:
    paragraph = {'lines': []}
    for id_ in line_group:
      paragraph['lines'].append(lines[id_])
    if paragraph:
      paragraphs.append(paragraph)

  return paragraphs
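
# NOTE: `utilities.DisjointSet` is defined elsewhere in this project. A
# minimal union-find sketch with the same two operations used above
# (`union` and `to_group`) could look like:
#
#   class TinyDisjointSet:
#
#     def __init__(self, n: int):
#       self._parent = list(range(n))
#
#     def _find(self, i: int) -> int:
#       while self._parent[i] != i:
#         self._parent[i] = self._parent[self._parent[i]]  # Path halving.
#         i = self._parent[i]
#       return i
#
#     def union(self, i: int, j: int):
#       self._parent[self._find(i)] = self._find(j)
#
#     def to_group(self):
#       groups = {}
#       for i in range(len(self._parent)):
#         groups.setdefault(self._find(i), []).append(i)
#       return list(groups.values())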
def main(argv: Sequence[str]) -> None:
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Get the list of images.
  img_lists = []
  img_lists.extend(_IMG_FILE.value)
  for img_dir in _IMG_DIR.value:
    img_lists.extend(tf.io.gfile.glob(os.path.join(img_dir, '*')))

  logging.info('Total number of input images: %d', len(img_lists))

  model = load_model()

  vis_dis = _VIS_DIR.value
  output = {'annotations': []}
  for img_file in tqdm.tqdm(img_lists):
    output['annotations'].append({
        'image_id': img_file.split('/')[-1].split('.')[0],
        'paragraphs': inference(img_file, model),
    })

    if vis_dis:
      key = output['annotations'][-1]['image_id']
      paragraphs = output['annotations'][-1]['paragraphs']
      img = cv2.cvtColor(cv2.imread(img_file), cv2.COLOR_BGR2RGB)
      word_bnds = []
      line_bnds = []
      para_bnds = []
      for paragraph in paragraphs:
        paragraph_points_list = []
        for line in paragraph['lines']:
          line_points_list = []
          for word in line['words']:
            word_bnds.append(
                np.array(word['vertices'], np.int32).reshape((-1, 1, 2)))
            line_points_list.extend(word['vertices'])
          paragraph_points_list.extend(line_points_list)

          line_points = np.array(line_points_list, np.int32)  # (N,2)
          left = int(np.min(line_points[:, 0]))
          top = int(np.min(line_points[:, 1]))
          right = int(np.max(line_points[:, 0]))
          bottom = int(np.max(line_points[:, 1]))
          line_bnds.append(
              np.array([[[left, top]], [[right, top]], [[right, bottom]],
                        [[left, bottom]]], np.int32))

        para_points = np.array(paragraph_points_list, np.int32)  # (N,2)
        left = int(np.min(para_points[:, 0]))
        top = int(np.min(para_points[:, 1]))
        right = int(np.max(para_points[:, 0]))
        bottom = int(np.max(para_points[:, 1]))
        para_bnds.append(
            np.array([[[left, top]], [[right, top]], [[right, bottom]],
                      [[left, bottom]]], np.int32))

      for name, bnds in zip(['paragraph', 'line', 'word'],
                            [para_bnds, line_bnds, word_bnds]):
        vis = cv2.polylines(img, bnds, True, (0, 0, 255), 2)
        cv2.imwrite(
            os.path.join(vis_dis, f'{key}-{name}.jpg'),
            cv2.cvtColor(vis, cv2.COLOR_RGB2BGR))

  with tf.io.gfile.GFile(_OUTPUT_PATH.value, mode='w') as f:
    f.write(json.dumps(output, ensure_ascii=False, indent=2))


if __name__ == '__main__':
  flags.mark_flags_as_required(['gin_file', 'ckpt_path', 'output_path'])
  app.run(main)
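
A minimal sketch (not part of the commit) of consuming the JSON written by run_inference.py, based on the output structure assembled in `main` and `inference` above; the path is whatever was passed via `--output_path`:

import json

with open('output.json') as f:
  result = json.load(f)
for annotation in result['annotations']:
  n_words = sum(
      len(line['words'])
      for paragraph in annotation['paragraphs']
      for line in paragraph['lines'])
  print(annotation['image_id'], len(annotation['paragraphs']), n_words)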
official/projects/unified_detector/tasks/all_models.py
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Import all models.
All model files are imported here so that they can be referenced in Gin. Also,
importing here avoids making ocr_task.py too messy.
"""
# pylint: disable=unused-import
from
official.projects.unified_detector.modeling
import
universal_detector
official/projects/unified_detector/tasks/ocr_task.py
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task definition for ocr."""
from typing import Callable, Dict, Optional, Sequence, Tuple, Union

import gin
import tensorflow as tf

from official.core import base_task
from official.core import config_definitions as cfg
from official.core import task_factory
from official.projects.unified_detector.configs import ocr_config
from official.projects.unified_detector.data_loaders import input_reader
from official.projects.unified_detector.tasks import all_models  # pylint: disable=unused-import
from official.projects.unified_detector.utils import typing

NestedTensorDict = typing.NestedTensorDict
ModelType = Union[tf.keras.layers.Layer, tf.keras.Model]


@task_factory.register_task_cls(ocr_config.OcrTaskConfig)
@gin.configurable
class OcrTask(base_task.Task):
  """Defining the OCR training task."""
  _loss_items = []

  def __init__(self,
               params: cfg.TaskConfig,
               logging_dir: Optional[str] = None,
               name: Optional[str] = None,
               model_fn: Callable[..., ModelType] = gin.REQUIRED):
    super().__init__(params, logging_dir, name)
    self._modef_fn = model_fn

  def build_model(self) -> ModelType:
    """Build and return the model; record the loss items as well."""
    model = self._modef_fn()
    self._loss_items.extend(model.loss_items)
    return model

  def build_inputs(
      self,
      params: cfg.DataConfig,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    """Build the tf.data.Dataset instance."""
    return input_reader.InputFn(is_training=params.is_training)({},
                                                                input_context)

  def build_metrics(self,
                    training: bool = True) -> Sequence[tf.keras.metrics.Metric]:
    """Build the metrics (currently, only for loss summaries in TensorBoard)."""
    del training
    metrics = []
    # Add loss items.
    for name in self._loss_items:
      metrics.append(tf.keras.metrics.Mean(name, dtype=tf.float32))
    # TODO(longshangbang): add evaluation metrics
    return metrics

  def train_step(
      self,
      inputs: Tuple[NestedTensorDict, NestedTensorDict],
      model: ModelType,
      optimizer: tf.keras.optimizers.Optimizer,
      metrics: Optional[Sequence[tf.keras.metrics.Metric]] = None
  ) -> Dict[str, tf.Tensor]:
    features, labels = inputs
    input_dict = {"features": features}
    if self.task_config.model_call_needs_labels:
      input_dict["labels"] = labels

    is_mixed_precision = isinstance(optimizer,
                                    tf.keras.mixed_precision.LossScaleOptimizer)

    with tf.GradientTape() as tape:
      outputs = model(**input_dict, training=True)
      loss, loss_dict = model.compute_losses(labels=labels, outputs=outputs)
      loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
      if is_mixed_precision:
        loss = optimizer.get_scaled_loss(loss)

    tvars = model.trainable_variables
    grads = tape.gradient(loss, tvars)
    if is_mixed_precision:
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    logs = {"loss": loss}
    if metrics:
      for m in metrics:
        m.update_state(loss_dict[m.name])
    return logs
official/projects/unified_detector/train.py
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorFlow Model Garden Vision training driver."""
from absl import app
from absl import flags
import gin

from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
# pylint: disable=unused-import
from official.projects.unified_detector import registry_imports
# pylint: enable=unused-import

FLAGS = flags.FLAGS


def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  params = train_utils.parse_configuration(FLAGS)
  model_dir = FLAGS.model_dir
  if 'train' in FLAGS.mode:
    # Pure eval modes do not output yaml files. Otherwise, a continuous eval
    # job may race against the train job for writing the same file.
    train_utils.serialize_config(params, model_dir)

  # Sets the mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can have a significant impact on model speed by utilizing float16 in the
  # case of GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect
  # only when dtype is float16.
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu)
  with distribution_strategy.scope():
    task = task_factory.get_task(params.task, logging_dir=model_dir)

  train_lib.run_experiment(
      distribution_strategy=distribution_strategy,
      task=task,
      mode=FLAGS.mode,
      params=params,
      model_dir=model_dir)

  train_utils.save_gin_config(FLAGS.mode, model_dir)


if __name__ == '__main__':
  tfm_flags.define_flags()
  flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
  app.run(main)
official/projects/unified_detector/utils/typing.py
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Typing extension."""
from typing import Dict, Union

import numpy as np
import tensorflow as tf

NpDict = Dict[str, np.ndarray]
FeaturesAndLabelsType = Dict[str, Dict[str, tf.Tensor]]
TensorDict = Dict[Union[str, int], tf.Tensor]
NestedTensorDict = Dict[Union[str, int], Union[tf.Tensor, TensorDict]]