ModelZoo / ResNet50_tensorflow / Commits

Commit ac1f5735, authored Jun 02, 2022 by Abdullah Rashwan, committed by A. Unique TensorFlower on Jun 02, 2022
Internal change
PiperOrigin-RevId: 452578619
parent 52902342
Showing 20 changed files with 3893 additions and 0 deletions (+3893, -0).
official/projects/unified_detector/README.md  +163 -0
official/projects/unified_detector/configs/gin_files/unified_detector_model.gin  +43 -0
official/projects/unified_detector/configs/gin_files/unified_detector_train.gin  +22 -0
official/projects/unified_detector/configs/ocr_config.py  +78 -0
official/projects/unified_detector/data_conversion/convert.py  +66 -0
official/projects/unified_detector/data_conversion/utils.py  +182 -0
official/projects/unified_detector/data_loaders/autoaugment.py  +753 -0
official/projects/unified_detector/data_loaders/input_reader.py  +270 -0
official/projects/unified_detector/data_loaders/tf_example_decoder.py  +320 -0
official/projects/unified_detector/data_loaders/universal_detection_parser.py  +606 -0
official/projects/unified_detector/docs/images/task.png  +0 -0
official/projects/unified_detector/external_configurables.py  +22 -0
official/projects/unified_detector/modeling/universal_detector.py  +888 -0
official/projects/unified_detector/registry_imports.py  +21 -0
official/projects/unified_detector/requirements.txt  +8 -0
official/projects/unified_detector/run_inference.py  +222 -0
official/projects/unified_detector/tasks/all_models.py  +23 -0
official/projects/unified_detector/tasks/ocr_task.py  +108 -0
official/projects/unified_detector/train.py  +70 -0
official/projects/unified_detector/utils/typing.py  +28 -0
official/projects/unified_detector/README.md (new file, mode 100644)

# Towards End-to-End Unified Scene Text Detection and Layout Analysis

[arXiv:2203.15143](https://arxiv.org/abs/2203.15143)

Official TensorFlow 2 implementation of the paper `Towards End-to-End Unified
Scene Text Detection and Layout Analysis`. If you encounter any issues using the
code, please report them on the Issues tab or email us directly at
`hiertext@google.com`.
## Installation

### Set up TensorFlow Models

```bash
# (Optional) Create and enter a virtual environment
pip3 install --user virtualenv
virtualenv -p python3 unified_detector
source ./unified_detector/bin/activate

# First clone the TensorFlow Models project:
git clone https://github.com/tensorflow/models.git

# Install the requirements of TensorFlow Models and this repo:
cd models
pip3 install -r official/requirements.txt
pip3 install -r official/projects/unified_detector/requirements.txt

# Compile the protos
# If `protoc` is not installed, please follow: https://grpc.io/docs/protoc-installation/
export PYTHONPATH=${PYTHONPATH}:${PWD}/research/
cd research/object_detection/
protoc protos/string_int_label_map.proto --python_out=.
```
### Set up Deeplab2

```bash
# Clone Deeplab2 anywhere you like
cd <somewhere>
git clone https://github.com/google-research/deeplab2.git

# Compile the protos
protoc deeplab2/*.proto --python_out=.

# Add to PYTHONPATH the directory where deeplab2 sits.
export PYTHONPATH=${PYTHONPATH}:${PWD}
```
## Running the model on images with the provided checkpoint

### Download the checkpoint

Model | Input Resolution | #object query | line PQ (val) | paragraph PQ (val) | line PQ (test) | paragraph PQ (test)
----- | ---------------- | ------------- | ------------- | ------------------ | -------------- | -------------------
Unified-Detector-Line ([ckpt](https://storage.cloud.google.com/tf_model_garden/vision/unified_detector/unified_detector_ckpt.tgz)) | 1024 | 384 | 61.04 | 52.84 | 62.20 | 53.52
### Demo on single images

```bash
# Run from `models/`
python3 -m official.projects.unified_detector.run_inference \
  --gin_file=official/projects/unified_detector/configs/gin_files/unified_detector_model.gin \
  --ckpt_path=<path-of-the-ckpt> \
  --img_file=<some-image> \
  --output_path=<some-directory>/demo.jsonl \
  --vis_dir=<some-directory>
```
The output is stored as JSONL in the same hierarchical format required by the
evaluation script of the HierText dataset. Visualizations of the
word/line/paragraph boundaries are also written to `--vis_dir`. Note that the
unified detector produces line-level masks and an affinity matrix for grouping
lines into paragraphs. For visualization purposes, we split each line mask into
pixel groups, defined as connected components, and visualize these groups as
`words`. They are not necessarily at the word granularity, though. Lines and
paragraphs are then visualized as axis-aligned bounding boxes around these
`words`, as sketched below.
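
For reference, a minimal sketch (not part of the original README) of the word-splitting and line-grouping logic described above. The helper names here are made up for illustration; the real implementation is `run_inference.py` in this commit, and `affinity` is assumed to be the model's line-to-line affinity matrix:

```python
import cv2
import numpy as np

def split_line_mask(mask: np.ndarray) -> list:
  """Splits one predicted line mask into connected components ("words")."""
  contours, _ = cv2.findContours(
      (mask > 0).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
  return [c[:, 0] for c in contours]  # one (N, 2) vertex array per component

def group_lines(affinity: np.ndarray, threshold: float = 0.5) -> list:
  """Union-find over line indices i, j whenever affinity[i, j] > threshold."""
  parent = list(range(affinity.shape[0]))
  def find(x):
    while parent[x] != x:
      parent[x] = parent[parent[x]]
      x = parent[x]
    return x
  for i, j in zip(*np.where(affinity > threshold)):
    parent[find(i)] = find(j)
  groups = {}
  for i in range(len(parent)):
    groups.setdefault(find(i), []).append(i)
  return list(groups.values())  # each group of line indices forms one paragraph
```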
## Inference and Evaluation on the HierText dataset

### Download the HierText dataset

Clone the [HierText repo](https://github.com/google-research-datasets/hiertext)
and download the dataset. The `requirements.txt` in this folder already covers
the requirements of the HierText repo, so there is no need to create a new
virtual environment.

### Inference and eval

The following commands run the model on the validation set and compute the
scores. Note that the test set annotations are not released yet, so only the
validation set is used here for demo purposes.
#### Inference

```bash
# Run from `models/`
python3 -m official.projects.unified_detector.run_inference \
  --gin_file=official/projects/unified_detector/configs/gin_files/unified_detector_model.gin \
  --ckpt_path=<path-of-the-ckpt> \
  --img_dir=<the-directory-containing-validation-images> \
  --output_path=<some-directory>/validation_output.jsonl
```
#### Evaluation

```bash
# Run from `hiertext/`
python3 eval.py \
  --gt=gt/validation.jsonl \
  --result=<some-directory>/validation_output.jsonl \
  --output=./validation-score.txt \
  --mask_stride=1 \
  --eval_lines \
  --eval_paragraphs \
  --num_workers=0
```
## Train new models

First, you will need to convert the HierText dataset into TFRecords:

```bash
# Run from `models/official/projects/unified_detector/data_conversion`
CUDA_VISIBLE_DEVICES='' python3 convert.py \
  --gt_file=/path/to/gt.jsonl \
  --img_dir=/path/to/image \
  --out_file=/path/to/tfrecords/file-prefix
```
To train the unified detector, run the following script:

```bash
# Run from `models/`
python3 -m official.projects.unified_detector.train \
  --mode=train \
  --experiment=unified_detector \
  --model_dir='<some path>' \
  --gin_file='official/projects/unified_detector/configs/gin_files/unified_detector_train.gin' \
  --gin_file='official/projects/unified_detector/configs/gin_files/unified_detector_model.gin' \
  --gin_params='InputFn.input_paths = ["/path/to/tfrecords/file-prefix*"]'
```
## Citation

Please cite our [paper](https://arxiv.org/pdf/2203.15143.pdf) if you find this
work helpful:

```
@inproceedings{long2022towards,
  title={Towards End-to-End Unified Scene Text Detection and Layout Analysis},
  author={Long, Shangbang and Qin, Siyang and Panteleev, Dmitry and Bissacco, Alessandro and Fujii, Yasuhisa and Raptis, Michalis},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2022}
}
```
official/projects/unified_detector/configs/gin_files/unified_detector_model.gin (new file, mode 100644)

# Defining the unified detector models.
# Model
## Backbone
num_slots = 384
SyncBatchNormalization.momentum = 0.95
get_max_deep_lab_backbone.num_slots = %num_slots
## Decoder
intermediate_filters = 256
num_entity_class = 3 # C + 1 (bkg) + 1 (void)
_get_decoder_head.atrous_rates = (6, 12, 18)
_get_decoder_head.pixel_space_dim = 128
_get_decoder_head.pixel_space_intermediate = %intermediate_filters
_get_decoder_head.num_classes = %num_entity_class
_get_decoder_head.aux_sem_intermediate = %intermediate_filters
_get_decoder_head.low_level = [
{'feature_key': 'res3', 'channels_project': 64,},
{'feature_key': 'res2', 'channels_project': 32,},]
_get_decoder_head.norm_fn = @SyncBatchNormalization
_get_embed_head.norm_fn = @LayerNorm
# Loss
# pq loss
alpha = 0.75
tau = 0.3
_entity_mask_loss.alpha = %alpha
_instance_discrimination_loss.tau = %tau
_paragraph_grouping_loss.tau = %tau
_paragraph_grouping_loss.loss_mode = 'balanced'
# Other Model setting
UniversalDetector.mask_threshold = 0.4
UniversalDetector.class_threshold = 0.5
UniversalDetector.filter_area = 32
universal_detection_loss_weights.loss_segmentation_word = 1e0
universal_detection_loss_weights.loss_inst_dist = 1e0
universal_detection_loss_weights.loss_mask_id = 1e-4
universal_detection_loss_weights.loss_pq = 3e0
universal_detection_loss_weights.loss_para = 1e0
official/projects/unified_detector/configs/gin_files/unified_detector_train.gin (new file, mode 100644)

# Defining the input pipeline of unified detector.
# ===== ===== Model ===== =====
# Internal import 2.
OcrTask.model_fn = @UniversalDetector
# ===== ===== Data pipeline ===== =====
InputFn.parser_fn = @UniDetectorParserFn
InputFn.dataset_type = 'tfrecord'
InputFn.batch_size = 256
# Internal import 3.
UniDetectorParserFn.output_dimension = 1024
# Simple data augmentation for now.
UniDetectorParserFn.rot90_probability = 0.0
UniDetectorParserFn.use_color_distortion = True
UniDetectorParserFn.crop_min_scale = 0.5
UniDetectorParserFn.crop_max_scale = 1.5
UniDetectorParserFn.crop_min_aspect = 0.8
UniDetectorParserFn.crop_max_aspect = 1.25
UniDetectorParserFn.max_num_instance = 384
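
For context, a minimal sketch (not part of the commit) of how these two Gin files are typically consumed. It mirrors `load_model()` in `run_inference.py` further down and assumes the project's registry imports have been done so that the `@...` references resolve:

```python
import gin

# Registry imports bring InputFn, UniversalDetector, SyncBatchNormalization,
# etc. into Gin's scope so the bindings in the files above can resolve.
from official.projects.unified_detector import registry_imports  # pylint: disable=unused-import
from official.projects.unified_detector.data_loaders import input_reader
from official.projects.unified_detector.modeling import universal_detector

gin.parse_config_file(
    'official/projects/unified_detector/configs/gin_files/unified_detector_model.gin')
gin.parse_config_file(
    'official/projects/unified_detector/configs/gin_files/unified_detector_train.gin')
# Bindings not present in the files (e.g. the TFRecord paths) can be added as strings.
gin.parse_config('InputFn.input_paths = ["/path/to/tfrecords/file-prefix*"]')

model = universal_detector.UniversalDetector()            # thresholds / loss weights from the model gin
train_input_fn = input_reader.InputFn(is_training=True)   # batch size / parser settings from the train gin
```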
official/projects/unified_detector/configs/ocr_config.py (new file, mode 100644)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""OCR tasks and models configurations."""
import
dataclasses
from
official.core
import
config_definitions
as
cfg
from
official.core
import
exp_factory
from
official.modeling
import
optimization
@
dataclasses
.
dataclass
class
OcrTaskConfig
(
cfg
.
TaskConfig
):
train_data
:
cfg
.
DataConfig
=
cfg
.
DataConfig
()
model_call_needs_labels
:
bool
=
False
@
exp_factory
.
register_config_factory
(
'unified_detector'
)
def
unified_detector
()
->
cfg
.
ExperimentConfig
:
"""Configurations for trainer of unified detector."""
total_train_steps
=
100000
summary_interval
=
steps_per_loop
=
200
checkpoint_interval
=
2000
warmup_steps
=
1000
config
=
cfg
.
ExperimentConfig
(
# Input pipeline and model are configured through Gin.
task
=
OcrTaskConfig
(
train_data
=
cfg
.
DataConfig
(
is_training
=
True
)),
trainer
=
cfg
.
TrainerConfig
(
train_steps
=
total_train_steps
,
steps_per_loop
=
steps_per_loop
,
summary_interval
=
summary_interval
,
checkpoint_interval
=
checkpoint_interval
,
max_to_keep
=
1
,
optimizer_config
=
optimization
.
OptimizationConfig
({
'optimizer'
:
{
'type'
:
'adamw'
,
'adamw'
:
{
'weight_decay_rate'
:
0.05
,
'include_in_weight_decay'
:
[
'^((?!depthwise).)*(kernel|weights):0$'
,
],
'exclude_from_weight_decay'
:
[
'(^((?!kernel).)*:0)|(depthwise_kernel)'
,
],
'gradient_clip_norm'
:
10.
,
},
},
'learning_rate'
:
{
'type'
:
'cosine'
,
'cosine'
:
{
'initial_learning_rate'
:
1e-3
,
'decay_steps'
:
total_train_steps
-
warmup_steps
,
'alpha'
:
1e-2
,
'offset'
:
warmup_steps
,
},
},
'warmup'
:
{
'type'
:
'linear'
,
'linear'
:
{
'warmup_learning_rate'
:
1e-5
,
'warmup_steps'
:
warmup_steps
,
}
},
}),
),
)
return
config
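
A quick usage sketch (not part of the commit) for retrieving the registered experiment configuration:

```python
from official.core import exp_factory
# Importing the config module runs the @exp_factory.register_config_factory decorator.
from official.projects.unified_detector.configs import ocr_config  # pylint: disable=unused-import

config = exp_factory.get_exp_config('unified_detector')
print(config.trainer.train_steps)                           # 100000
print(config.trainer.optimizer_config.optimizer.type)       # 'adamw'
print(config.trainer.optimizer_config.learning_rate.type)   # 'cosine'
```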
official/projects/unified_detector/data_conversion/convert.py (new file, mode 100644)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Script to convert HierText to TFExamples.

This script is only intended to run locally.

python3 data_preprocess/convert.py \
  --gt_file=/path/to/gt.jsonl \
  --img_dir=/path/to/image \
  --out_file=/path/to/tfrecords/file-prefix
"""

import json
import os
import random

from absl import app
from absl import flags
import tensorflow as tf
import tqdm

import utils

_GT_FILE = flags.DEFINE_string('gt_file', None, 'Path to the GT file')
_IMG_DIR = flags.DEFINE_string('img_dir', None, 'Path to the image folder.')
_OUT_FILE = flags.DEFINE_string('out_file', None, 'Path for the tfrecords.')
_NUM_SHARD = flags.DEFINE_integer('num_shard', 100,
                                  'The number of shards of tfrecords.')


def main(unused_argv) -> None:
  annotations = json.load(open(_GT_FILE.value))['annotations']
  random.shuffle(annotations)
  n_sample = len(annotations)
  n_shards = _NUM_SHARD.value
  n_sample_per_shard = (n_sample - 1) // n_shards + 1

  for shard in tqdm.tqdm(range(n_shards)):
    output_path = f'{_OUT_FILE.value}-{shard:05}-{n_shards:05}.tfrecords'
    annotation_subset = annotations[shard * n_sample_per_shard:(shard + 1) *
                                    n_sample_per_shard]
    with tf.io.TFRecordWriter(output_path) as file_writer:
      for annotation in annotation_subset:
        img_file_path = os.path.join(_IMG_DIR.value,
                                     f"{annotation['image_id']}.jpg")
        tfexample = utils.convert_to_tfe(img_file_path, annotation)
        file_writer.write(tfexample)


if __name__ == '__main__':
  flags.mark_flags_as_required(['gt_file', 'img_dir', 'out_file'])
  app.run(main)
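
A small sketch (not part of the commit) for sanity-checking one of the generated shards; the shard name below is just an example of the `-{shard:05}-{n_shards:05}.tfrecords` pattern used above:

```python
import tensorflow as tf

dataset = tf.data.TFRecordDataset(
    '/path/to/tfrecords/file-prefix-00000-00100.tfrecords')
for serialized in dataset.take(1):
  example = tf.train.Example.FromString(serialized.numpy())
  feature = example.features.feature
  print(feature['image/source_id'].bytes_list.value[0])        # image id
  print(feature['image/height'].int64_list.value[0],
        feature['image/width'].int64_list.value[0])             # image size
  print(feature['image/object/classes'].int64_list.value[:10])  # 1=word, 2=line, 3=paragraph
```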
official/projects/unified_detector/data_conversion/utils.py (new file, mode 100644)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities to convert data to TFExamples and store in TFRecords."""
from
typing
import
Any
,
Dict
,
List
,
Tuple
,
Union
import
cv2
import
numpy
as
np
import
tensorflow
as
tf
def
encode_image
(
image_tensor
:
np
.
ndarray
,
encoding_type
:
str
=
'png'
)
->
Union
[
np
.
ndarray
,
tf
.
Tensor
]:
"""Encode image tensor into byte string."""
if
encoding_type
==
'jpg'
:
image_encoded
=
tf
.
image
.
encode_jpeg
(
tf
.
constant
(
image_tensor
))
elif
encoding_type
==
'png'
:
image_encoded
=
tf
.
image
.
encode_png
(
tf
.
constant
(
image_tensor
))
else
:
raise
ValueError
(
'Invalid encoding type.'
)
if
tf
.
executing_eagerly
():
image_encoded
=
image_encoded
.
numpy
()
else
:
image_encoded
=
image_encoded
.
eval
()
return
image_encoded
def
int64_feature
(
value
:
Union
[
int
,
List
[
int
]])
->
tf
.
train
.
Feature
:
if
not
isinstance
(
value
,
list
):
value
=
[
value
]
return
tf
.
train
.
Feature
(
int64_list
=
tf
.
train
.
Int64List
(
value
=
value
))
def
float_feature
(
value
:
Union
[
float
,
List
[
float
]])
->
tf
.
train
.
Feature
:
if
not
isinstance
(
value
,
list
):
value
=
[
value
]
return
tf
.
train
.
Feature
(
float_list
=
tf
.
train
.
FloatList
(
value
=
value
))
def
bytes_feature
(
value
:
Union
[
Union
[
bytes
,
str
],
List
[
Union
[
bytes
,
str
]]]
)
->
tf
.
train
.
Feature
:
if
not
isinstance
(
value
,
list
):
value
=
[
value
]
for
i
in
range
(
len
(
value
)):
if
not
isinstance
(
value
[
i
],
bytes
):
value
[
i
]
=
value
[
i
].
encode
(
'utf-8'
)
return
tf
.
train
.
Feature
(
bytes_list
=
tf
.
train
.
BytesList
(
value
=
value
))
def
annotation_to_entities
(
annotation
:
Dict
[
str
,
Any
])
->
List
[
Dict
[
str
,
Any
]]:
"""Flatten the annotation dict to a list of 'entities'."""
entities
=
[]
for
paragraph
in
annotation
[
'paragraphs'
]:
paragraph_id
=
len
(
entities
)
paragraph
[
'type'
]
=
3
# 3 for paragraph
paragraph
[
'parent_id'
]
=
-
1
entities
.
append
(
paragraph
)
for
line
in
paragraph
[
'lines'
]:
line_id
=
len
(
entities
)
line
[
'type'
]
=
2
# 2 for line
line
[
'parent_id'
]
=
paragraph_id
entities
.
append
(
line
)
for
word
in
line
[
'words'
]:
word
[
'type'
]
=
1
# 1 for word
word
[
'parent_id'
]
=
line_id
entities
.
append
(
word
)
return
entities
def
draw_entity_mask
(
entities
:
List
[
Dict
[
str
,
Any
]],
image_shape
:
Tuple
[
int
,
int
,
int
])
->
np
.
ndarray
:
"""Draw entity id mask.
Args:
entities: A list of entity objects. Should be output from
`annotation_to_entities`.
image_shape: The shape of the input image.
Returns:
A (H, W, 3) entity id mask of the same height/width as the image. Each pixel
(i, j, :) encodes the entity id of one pixel. Only word entities are
rendered. 0 for non-text pixels; word entity ids start from 1.
"""
instance_mask
=
np
.
zeros
(
image_shape
,
dtype
=
np
.
uint8
)
for
i
,
entity
in
enumerate
(
entities
):
# only draw word masks
if
entity
[
'type'
]
!=
1
:
continue
vertices
=
np
.
array
(
entity
[
'vertices'
])
# the pixel value is actually 1 + position in entities
entity_id
=
i
+
1
if
entity_id
>=
65536
:
# As entity_id is encoded in the last two channels, it should be less than
# 256**2=65536.
raise
ValueError
(
(
f
'Entity ID overflow:
{
entity_id
}
. Currently only entity_id<65536 '
'are supported.'
))
# use the last two channels to encode the entity id.
color
=
[
0
,
entity_id
//
256
,
entity_id
%
256
]
instance_mask
=
cv2
.
fillPoly
(
instance_mask
,
[
np
.
round
(
vertices
).
astype
(
'int32'
)],
color
)
return
instance_mask
def
convert_to_tfe
(
img_file_name
:
str
,
annotation
:
Dict
[
str
,
Any
])
->
tf
.
train
.
Example
:
"""Convert the annotation dict into a TFExample."""
img
=
cv2
.
imread
(
img_file_name
)
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2RGB
)
h
,
w
,
c
=
img
.
shape
encoded_img
=
encode_image
(
img
)
entities
=
annotation_to_entities
(
annotation
)
masks
=
draw_entity_mask
(
entities
,
img
.
shape
)
encoded_mask
=
encode_image
(
masks
)
# encode attributes
parent
=
[]
classes
=
[]
content_type
=
[]
text
=
[]
vertices
=
[]
for
entity
in
entities
:
parent
.
append
(
entity
[
'parent_id'
])
classes
.
append
(
entity
[
'type'
])
# 0 for annotated; 8 for not annotated
content_type
.
append
((
0
if
entity
[
'legible'
]
else
8
))
text
.
append
(
entity
.
get
(
'text'
,
''
))
v
=
np
.
array
(
entity
[
'vertices'
])
vertices
.
append
(
','
.
join
(
str
(
float
(
n
))
for
n
in
v
.
reshape
(
-
1
)))
example
=
tf
.
train
.
Example
(
features
=
tf
.
train
.
Features
(
feature
=
{
# input images
'image/encoded'
:
bytes_feature
(
encoded_img
),
# image format
'image/format'
:
bytes_feature
(
'png'
),
# image width
'image/width'
:
int64_feature
([
w
]),
# image height
'image/height'
:
int64_feature
([
h
]),
# image channels
'image/channels'
:
int64_feature
([
c
]),
# image key
'image/source_id'
:
bytes_feature
(
annotation
[
'image_id'
]),
# HxWx3 tensors: channel 2-3 encodes the id of the word entity.
'image/additional_channels/encoded'
:
bytes_feature
(
encoded_mask
),
# format of the additional channels
'image/additional_channels/format'
:
bytes_feature
(
'png'
),
'image/object/parent'
:
int64_feature
(
parent
),
# word / line / paragraph / symbol / ...
'image/object/classes'
:
int64_feature
(
classes
),
# text / handwritten / not-annotated / ...
'image/object/content_type'
:
int64_feature
(
content_type
),
# string text transcription
'image/object/text'
:
bytes_feature
(
text
),
# comma separated coordinates, (x,y) * n
'image/object/vertices'
:
bytes_feature
(
vertices
),
})).
SerializeToString
()
return
example
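
As a quick check of the channel encoding used by `draw_entity_mask`, a short sketch (not part of the commit) that recovers the per-pixel word entity id from the (H, W, 3) mask:

```python
import numpy as np

def decode_entity_ids(mask: np.ndarray) -> np.ndarray:
  """Inverse of the `[0, id // 256, id % 256]` encoding; 0 means non-text."""
  return mask[..., 1].astype(np.int32) * 256 + mask[..., 2].astype(np.int32)
```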
official/projects/unified_detector/data_loaders/autoaugment.py (new file, mode 100644; diff collapsed, contents not shown)

official/projects/unified_detector/data_loaders/input_reader.py (new file, mode 100644)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Input data reader.
Creates a tf.data.Dataset object from multiple input sstables and use a
provided data parser function to decode the serialized tf.Example and optionally
run data augmentation.
"""
import
os
from
typing
import
Any
,
Callable
,
List
,
Optional
,
Sequence
,
Union
import
gin
from
six.moves
import
map
import
tensorflow
as
tf
from
official.common
import
dataset_fn
from
research.object_detection.utils
import
label_map_util
from
official.core
import
config_definitions
as
cfg
from
official.projects.unified_detector.data_loaders
import
universal_detection_parser
# pylint: disable=unused-import
FuncType
=
Callable
[...,
Any
]
@
gin
.
configurable
(
denylist
=
[
'is_training'
])
class
InputFn
(
object
):
"""Input data reader class.
Creates a tf.data.Dataset object from multiple datasets (optionally performs
weighted sampling between different datasets), parses the tf.Example message
using `parser_fn`. The datasets can either be stored in SSTable or TfRecord.
"""
def
__init__
(
self
,
is_training
:
bool
,
batch_size
:
Optional
[
int
]
=
None
,
data_root
:
str
=
''
,
input_paths
:
List
[
str
]
=
gin
.
REQUIRED
,
dataset_type
:
str
=
'tfrecord'
,
use_sampling
:
bool
=
False
,
sampling_weights
:
Optional
[
Sequence
[
Union
[
int
,
float
]]]
=
None
,
cycle_length
:
Optional
[
int
]
=
64
,
shuffle_buffer_size
:
Optional
[
int
]
=
512
,
parser_fn
:
Optional
[
FuncType
]
=
None
,
parser_num_parallel_calls
:
Optional
[
int
]
=
64
,
max_intra_op_parallelism
:
Optional
[
int
]
=
None
,
label_map_proto_path
:
Optional
[
str
]
=
None
,
input_filter_fns
:
Optional
[
List
[
FuncType
]]
=
None
,
input_training_filter_fns
:
Optional
[
Sequence
[
FuncType
]]
=
None
,
dense_to_ragged_batch
:
bool
=
False
,
data_validator_fn
:
Optional
[
Callable
[[
Sequence
[
str
]],
None
]]
=
None
):
"""Input reader constructor.
Args:
is_training: Boolean indicating TRAIN or EVAL.
batch_size: Input data batch size. Ignored if batch size is passed through
params. In that case, this can be None.
data_root: All the relative input paths are based on this location.
input_paths: Input file patterns.
dataset_type: Can be 'sstable' or 'tfrecord'.
use_sampling: Whether to perform weighted sampling between different
datasets.
sampling_weights: Unnormalized sampling weights. The length should be
equal to `input_paths`.
cycle_length: The number of input Datasets to interleave from in parallel.
If set to None tf.data experimental autotuning is used.
shuffle_buffer_size: The random shuffle buffer size.
parser_fn: The function to run decoding and data augmentation. The
function takes `is_training` as an input, which is passed from here.
parser_num_parallel_calls: The number of parallel calls for `parser_fn`.
The number of CPU cores is the suggested value. If set to None tf.data
experimental autotuning is used.
max_intra_op_parallelism: if set limits the max intra op parallelism of
functions run on slices of the input.
label_map_proto_path: Path to a StringIntLabelMap which will be used to
decode the input data.
input_filter_fns: A list of functions on the dataset points which returns
true for valid data.
input_training_filter_fns: A list of functions on the dataset points which
returns true for valid data used only for training.
dense_to_ragged_batch: Whether to use ragged batching for MPNN format.
data_validator_fn: If not None, used to validate the data specified by
input_paths.
Raises:
ValueError for invalid input_paths.
"""
self
.
_is_training
=
is_training
if
data_root
:
# If an input path is absolute this does not change it.
input_paths
=
[
os
.
path
.
join
(
data_root
,
value
)
for
value
in
input_paths
]
self
.
_input_paths
=
input_paths
# Disables datasets sampling during eval.
self
.
_batch_size
=
batch_size
if
is_training
:
self
.
_use_sampling
=
use_sampling
else
:
self
.
_use_sampling
=
False
self
.
_sampling_weights
=
sampling_weights
self
.
_cycle_length
=
(
cycle_length
if
cycle_length
else
tf
.
data
.
AUTOTUNE
)
self
.
_shuffle_buffer_size
=
shuffle_buffer_size
self
.
_parser_num_parallel_calls
=
(
parser_num_parallel_calls
if
parser_num_parallel_calls
else
tf
.
data
.
AUTOTUNE
)
self
.
_max_intra_op_parallelism
=
max_intra_op_parallelism
self
.
_label_map_proto_path
=
label_map_proto_path
if
label_map_proto_path
:
name_to_id
=
label_map_util
.
get_label_map_dict
(
label_map_proto_path
)
self
.
_lookup_str_keys
=
list
(
name_to_id
.
keys
())
self
.
_lookup_int_values
=
list
(
name_to_id
.
values
())
self
.
_parser_fn
=
parser_fn
self
.
_input_filter_fns
=
input_filter_fns
or
[]
if
is_training
and
input_training_filter_fns
:
self
.
_input_filter_fns
.
extend
(
input_training_filter_fns
)
self
.
_dataset_type
=
dataset_type
self
.
_dense_to_ragged_batch
=
dense_to_ragged_batch
if
data_validator_fn
is
not
None
:
data_validator_fn
(
self
.
_input_paths
)
@
property
def
batch_size
(
self
):
return
self
.
_batch_size
def
__call__
(
self
,
params
:
cfg
.
DataConfig
,
input_context
:
Optional
[
tf
.
distribute
.
InputContext
]
=
None
)
->
tf
.
data
.
Dataset
:
"""Read and parse input datasets, return a tf.data.Dataset object."""
# TPUEstimator passes the batch size through params.
if
params
is
not
None
and
'batch_size'
in
params
:
batch_size
=
params
[
'batch_size'
]
else
:
batch_size
=
self
.
_batch_size
per_replica_batch_size
=
input_context
.
get_per_replica_batch_size
(
batch_size
)
if
input_context
else
batch_size
with
tf
.
name_scope
(
'input_reader'
):
dataset
=
self
.
_build_dataset_from_records
()
dataset_parser_fn
=
self
.
_build_dataset_parser_fn
()
dataset
=
dataset
.
map
(
dataset_parser_fn
,
num_parallel_calls
=
self
.
_parser_num_parallel_calls
)
for
filter_fn
in
self
.
_input_filter_fns
:
dataset
=
dataset
.
filter
(
filter_fn
)
if
self
.
_dense_to_ragged_batch
:
dataset
=
dataset
.
apply
(
tf
.
data
.
experimental
.
dense_to_ragged_batch
(
batch_size
=
per_replica_batch_size
,
drop_remainder
=
True
))
else
:
dataset
=
dataset
.
batch
(
per_replica_batch_size
,
drop_remainder
=
True
)
dataset
=
dataset
.
prefetch
(
tf
.
data
.
AUTOTUNE
)
return
dataset
def
_fetch_dataset
(
self
,
filename
:
str
)
->
tf
.
data
.
Dataset
:
"""Fetch dataset depending on type.
Args:
filename: Location of dataset.
Returns:
Tf Dataset.
"""
data_cls
=
dataset_fn
.
pick_dataset_fn
(
self
.
_dataset_type
)
data
=
data_cls
([
filename
])
return
data
def
_build_dataset_parser_fn
(
self
)
->
Callable
[...,
tf
.
Tensor
]:
"""Depending on label_map and storage type, build a parser_fn."""
# Parse the fetched records to input tensors for model function.
if
self
.
_label_map_proto_path
:
lookup_initializer
=
tf
.
lookup
.
KeyValueTensorInitializer
(
keys
=
tf
.
constant
(
self
.
_lookup_str_keys
,
dtype
=
tf
.
string
),
values
=
tf
.
constant
(
self
.
_lookup_int_values
,
dtype
=
tf
.
int32
))
name_to_id_table
=
tf
.
lookup
.
StaticHashTable
(
initializer
=
lookup_initializer
,
default_value
=
0
)
parser_fn
=
self
.
_parser_fn
(
is_training
=
self
.
_is_training
,
label_lookup_table
=
name_to_id_table
)
else
:
parser_fn
=
self
.
_parser_fn
(
is_training
=
self
.
_is_training
)
return
parser_fn
def
_build_dataset_from_records
(
self
)
->
tf
.
data
.
Dataset
:
"""Build a tf.data.Dataset object from input SSTables.
If the input data come from multiple SSTables, use the user defined sampling
weights to perform sampling. For example, if the sampling weights is
[1., 2.], the second dataset will be sampled twice more often than the first
one.
Returns:
Dataset built from SSTables.
Raises:
ValueError for inability to find SSTable files.
"""
all_file_patterns
=
[]
if
self
.
_use_sampling
:
for
file_pattern
in
self
.
_input_paths
:
all_file_patterns
.
append
([
file_pattern
])
# Normalize sampling probabilities.
total_weight
=
sum
(
self
.
_sampling_weights
)
sampling_probabilities
=
[
float
(
w
)
/
total_weight
for
w
in
self
.
_sampling_weights
]
else
:
all_file_patterns
.
append
(
self
.
_input_paths
)
datasets
=
[]
for
file_pattern
in
all_file_patterns
:
filenames
=
sum
(
list
(
map
(
tf
.
io
.
gfile
.
glob
,
file_pattern
)),
[])
if
not
filenames
:
raise
ValueError
(
f
'Error trying to read input files for file pattern
{
file_pattern
}
'
)
# Create a dataset of filenames and shuffle the files. In each epoch,
# the file order is shuffled again. This may help if
# per_host_input_for_training = false on TPU.
dataset
=
tf
.
data
.
Dataset
.
list_files
(
file_pattern
,
shuffle
=
self
.
_is_training
)
if
self
.
_is_training
:
dataset
=
dataset
.
repeat
()
if
self
.
_max_intra_op_parallelism
:
# Disable intra-op parallelism to optimize for throughput instead of
# latency.
options
=
tf
.
data
.
Options
()
options
.
experimental_threading
.
max_intra_op_parallelism
=
1
dataset
=
dataset
.
with_options
(
options
)
dataset
=
dataset
.
interleave
(
self
.
_fetch_dataset
,
cycle_length
=
self
.
_cycle_length
,
num_parallel_calls
=
self
.
_cycle_length
,
deterministic
=
(
not
self
.
_is_training
))
if
self
.
_is_training
:
dataset
=
dataset
.
shuffle
(
self
.
_shuffle_buffer_size
)
datasets
.
append
(
dataset
)
if
self
.
_use_sampling
:
assert
len
(
datasets
)
==
len
(
sampling_probabilities
)
dataset
=
tf
.
data
.
experimental
.
sample_from_datasets
(
datasets
,
sampling_probabilities
)
else
:
dataset
=
datasets
[
0
]
return
dataset
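
For illustration, a hedged sketch (not part of the commit) of the weighted-sampling behaviour documented in `_build_dataset_from_records`; the paths and `my_parser_fn` below are placeholders, not real files or functions:

```python
# With weights [1., 2.], examples from the second pattern are drawn roughly
# twice as often as examples from the first (weights are normalized internally).
reader = InputFn(
    is_training=True,
    batch_size=8,
    input_paths=['/data/source_a*', '/data/source_b*'],
    use_sampling=True,
    sampling_weights=[1., 2.],
    parser_fn=my_parser_fn,  # hypothetical parser factory taking `is_training`
)
dataset = reader(params=None)
```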
official/projects/unified_detector/data_loaders/tf_example_decoder.py (new file, mode 100644)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tensorflow Example proto decoder for GOCR."""
from
typing
import
List
,
Optional
,
Sequence
,
Tuple
,
Union
import
tensorflow
as
tf
from
official.projects.unified_detector.utils.typing
import
TensorDict
from
official.vision.dataloaders
import
decoder
class
TfExampleDecoder
(
decoder
.
Decoder
):
"""Tensorflow Example proto decoder."""
def
__init__
(
self
,
use_instance_mask
:
bool
=
False
,
additional_class_names
:
Optional
[
Sequence
[
str
]]
=
None
,
additional_regression_names
:
Optional
[
Sequence
[
str
]]
=
None
,
num_additional_channels
:
int
=
0
):
"""Constructor.
keys_to_features is a dictionary mapping the names of the tf.Example
fields to tf features, possibly with defaults.
Uses fixed length for scalars and variable length for vectors.
Args:
use_instance_mask: if False, prevents decoding of the instance mask, which
can take a lot of resources.
additional_class_names: If not none, a list of additional class names. For
additional class name n, named image/object/${n} are expected to be an
int vector of length one, and are mapped to tensor dict key
groundtruth_${n}.
additional_regression_names: If not none, a list of additional regression
output names. For additional class name n, named image/object/${n} are
expected to be a float vector, and are mapped to tensor dict key
groundtruth_${n}.
num_additional_channels: The number of additional channels of information
present in the tf.Example proto.
"""
self
.
_num_additional_channels
=
num_additional_channels
self
.
_use_instance_mask
=
use_instance_mask
self
.
keys_to_features
=
{}
# Map names in the final tensor dict (output of `self.decode()`) to names in
# tf examples, e.g. 'groundtruth_text' -> 'image/object/text'
self
.
name_to_key
=
{}
if
use_instance_mask
:
self
.
keys_to_features
.
update
({
'image/object/mask'
:
tf
.
io
.
VarLenFeature
(
tf
.
string
),
})
# Now we have lists of standard types.
# To add new features, just add entries here.
# The tuple elements are (example name, tensor name, default value).
# If the items_to_handlers part is already set up use None for
# the tensor name.
# There are other tensor names listed as None which we probably
# want to discuss and specify.
scalar_strings
=
[
(
'image/encoded'
,
None
,
''
),
(
'image/format'
,
None
,
'jpg'
),
(
'image/additional_channels/encoded'
,
None
,
''
),
(
'image/additional_channels/format'
,
None
,
'png'
),
(
'image/label_type'
,
'label_type'
,
''
),
(
'image/key'
,
'key'
,
''
),
(
'image/source_id'
,
'source_id'
,
''
),
]
vector_strings
=
[
(
'image/attributes'
,
None
,
''
),
(
'image/object/text'
,
'groundtruth_text'
,
''
),
(
'image/object/encoded_text'
,
'groundtruth_encoded_text'
,
''
),
(
'image/object/vertices'
,
'groundtruth_vertices'
,
''
),
(
'image/object/object_type'
,
None
,
''
),
(
'image/object/language'
,
'language'
,
''
),
(
'image/object/reorderer_type'
,
None
,
''
),
(
'image/label_map_path'
,
'label_map_path'
,
''
)
]
scalar_ints
=
[
(
'image/height'
,
None
,
1
),
(
'image/width'
,
None
,
1
),
(
'image/channels'
,
None
,
3
),
]
vector_ints
=
[
(
'image/object/classes'
,
'groundtruth_classes'
,
0
),
(
'image/object/frame_id'
,
'frame_id'
,
0
),
(
'image/object/track_id'
,
'track_id'
,
0
),
(
'image/object/content_type'
,
'groundtruth_content_type'
,
0
),
]
if
additional_class_names
:
vector_ints
+=
[(
'image/object/%s'
%
name
,
'groundtruth_%s'
%
name
,
0
)
for
name
in
additional_class_names
]
# This one is not yet needed:
# scalar_floats = [
# ]
vector_floats
=
[
(
'image/object/weight'
,
'groundtruth_weight'
,
0
),
(
'image/object/rbox_tl_x'
,
None
,
0
),
(
'image/object/rbox_tl_y'
,
None
,
0
),
(
'image/object/rbox_width'
,
None
,
0
),
(
'image/object/rbox_height'
,
None
,
0
),
(
'image/object/rbox_angle'
,
None
,
0
),
(
'image/object/bbox/xmin'
,
None
,
0
),
(
'image/object/bbox/xmax'
,
None
,
0
),
(
'image/object/bbox/ymin'
,
None
,
0
),
(
'image/object/bbox/ymax'
,
None
,
0
),
]
if
additional_regression_names
:
vector_floats
+=
[(
'image/object/%s'
%
name
,
'groundtruth_%s'
%
name
,
0
)
for
name
in
additional_regression_names
]
self
.
_init_scalar_features
(
scalar_strings
,
tf
.
string
)
self
.
_init_vector_features
(
vector_strings
,
tf
.
string
)
self
.
_init_scalar_features
(
scalar_ints
,
tf
.
int64
)
self
.
_init_vector_features
(
vector_ints
,
tf
.
int64
)
self
.
_init_vector_features
(
vector_floats
,
tf
.
float32
)
def
_init_scalar_features
(
self
,
feature_list
:
List
[
Tuple
[
str
,
Optional
[
str
],
Union
[
str
,
int
,
float
]]],
ftype
:
tf
.
dtypes
.
DType
)
->
None
:
for
entry
in
feature_list
:
self
.
keys_to_features
[
entry
[
0
]]
=
tf
.
io
.
FixedLenFeature
(
(),
ftype
,
default_value
=
entry
[
2
])
if
entry
[
1
]
is
not
None
:
self
.
name_to_key
[
entry
[
1
]]
=
entry
[
0
]
def
_init_vector_features
(
self
,
feature_list
:
List
[
Tuple
[
str
,
Optional
[
str
],
Union
[
str
,
int
,
float
]]],
ftype
:
tf
.
dtypes
.
DType
)
->
None
:
for
entry
in
feature_list
:
self
.
keys_to_features
[
entry
[
0
]]
=
tf
.
io
.
VarLenFeature
(
ftype
)
if
entry
[
1
]
is
not
None
:
self
.
name_to_key
[
entry
[
1
]]
=
entry
[
0
]
def
_decode_png_instance_masks
(
self
,
keys_to_tensors
:
TensorDict
)
->
tf
.
Tensor
:
"""Decode PNG instance segmentation masks and stack into dense tensor.
The instance segmentation masks are reshaped to [num_instances, height,
width].
Args:
keys_to_tensors: A dictionary from keys to tensors.
Returns:
A 3-D float tensor of shape [num_instances, height, width] with values
in {0, 1}.
"""
def
decode_png_mask
(
image_buffer
):
image
=
tf
.
squeeze
(
tf
.
image
.
decode_image
(
image_buffer
,
channels
=
1
),
axis
=
2
)
image
.
set_shape
([
None
,
None
])
image
=
tf
.
to_float
(
tf
.
greater
(
image
,
0
))
return
image
png_masks
=
keys_to_tensors
[
'image/object/mask'
]
height
=
keys_to_tensors
[
'image/height'
]
width
=
keys_to_tensors
[
'image/width'
]
if
isinstance
(
png_masks
,
tf
.
SparseTensor
):
png_masks
=
tf
.
sparse_tensor_to_dense
(
png_masks
,
default_value
=
''
)
return
tf
.
cond
(
tf
.
greater
(
tf
.
size
(
png_masks
),
0
),
lambda
:
tf
.
map_fn
(
decode_png_mask
,
png_masks
,
dtype
=
tf
.
float32
),
lambda
:
tf
.
zeros
(
tf
.
to_int32
(
tf
.
stack
([
0
,
height
,
width
]))))
def
_decode_image
(
self
,
parsed_tensors
:
TensorDict
,
channel
:
int
=
3
)
->
TensorDict
:
"""Decodes the image and set its shape (H, W are dynamic; C is fixed)."""
image
=
tf
.
io
.
decode_image
(
parsed_tensors
[
'image/encoded'
],
channels
=
channel
)
image
.
set_shape
([
None
,
None
,
channel
])
return
{
'image'
:
image
}
def
_decode_additional_channels
(
self
,
parsed_tensors
:
TensorDict
,
channel
:
int
=
3
)
->
TensorDict
:
"""Decodes the additional channels and set its static shape."""
channels
=
tf
.
io
.
decode_image
(
parsed_tensors
[
'image/additional_channels/encoded'
],
channels
=
channel
)
channels
.
set_shape
([
None
,
None
,
channel
])
return
{
'additional_channels'
:
channels
}
def
_decode_boxes
(
self
,
parsed_tensors
:
TensorDict
)
->
TensorDict
:
"""Concat box coordinates in the format of [ymin, xmin, ymax, xmax]."""
xmin
=
parsed_tensors
[
'image/object/bbox/xmin'
]
xmax
=
parsed_tensors
[
'image/object/bbox/xmax'
]
ymin
=
parsed_tensors
[
'image/object/bbox/ymin'
]
ymax
=
parsed_tensors
[
'image/object/bbox/ymax'
]
return
{
'groundtruth_aligned_boxes'
:
tf
.
stack
([
ymin
,
xmin
,
ymax
,
xmax
],
axis
=-
1
)
}
def
_decode_rboxes
(
self
,
parsed_tensors
:
TensorDict
)
->
TensorDict
:
"""Concat rbox coordinates: [left, top, box_width, box_height, angle]."""
top_left_x
=
parsed_tensors
[
'image/object/rbox_tl_x'
]
top_left_y
=
parsed_tensors
[
'image/object/rbox_tl_y'
]
width
=
parsed_tensors
[
'image/object/rbox_width'
]
height
=
parsed_tensors
[
'image/object/rbox_height'
]
angle
=
parsed_tensors
[
'image/object/rbox_angle'
]
return
{
'groundtruth_boxes'
:
tf
.
stack
([
top_left_x
,
top_left_y
,
width
,
height
,
angle
],
axis
=-
1
)
}
def
_decode_masks
(
self
,
parsed_tensors
:
TensorDict
)
->
TensorDict
:
"""Decode a set of PNG masks to the tf.float32 tensors."""
def
_decode_png_mask
(
png_bytes
):
mask
=
tf
.
squeeze
(
tf
.
io
.
decode_png
(
png_bytes
,
channels
=
1
,
dtype
=
tf
.
uint8
),
axis
=-
1
)
mask
=
tf
.
cast
(
mask
,
dtype
=
tf
.
float32
)
mask
.
set_shape
([
None
,
None
])
return
mask
height
=
parsed_tensors
[
'image/height'
]
width
=
parsed_tensors
[
'image/width'
]
masks
=
parsed_tensors
[
'image/object/mask'
]
masks
=
tf
.
cond
(
pred
=
tf
.
greater
(
tf
.
size
(
input
=
masks
),
0
),
true_fn
=
lambda
:
tf
.
map_fn
(
_decode_png_mask
,
masks
,
dtype
=
tf
.
float32
),
false_fn
=
lambda
:
tf
.
zeros
([
0
,
height
,
width
],
dtype
=
tf
.
float32
))
return
{
'groundtruth_instance_masks'
:
masks
}
def
decode
(
self
,
tf_example_string_tensor
:
tf
.
string
):
"""Decodes serialized tensorflow example and returns a tensor dictionary.
Args:
tf_example_string_tensor: A string tensor holding a serialized tensorflow
example proto.
Returns:
A dictionary contains a subset of the following, depends on the inputs:
image: A uint8 tensor of shape [height, width, 3] containing the image.
source_id: A string tensor contains image fingerprint.
key: A string tensor contains the unique sha256 hash key.
label_type: Either `full` or `partial`. `full` means all the text are
fully labeled, `partial` otherwise. Currently, this is used by E2E
model. If an input image is fully labeled, we update the weights of
both the detection and the recognizer. Otherwise, only recognizer part
of the model is trained.
groundtruth_text: A string tensor list, the original transcriptions.
groundtruth_encoded_text: A string tensor list, the class ids for the
atoms in the text, after applying the reordering algorithm, in string
form. For example "90,71,85,69,86,85,93,90,71,91,1,71,85,93,90,71".
This depends on the class label map provided to the conversion
program. These are 0 based, with -1 for OOV symbols.
groundtruth_classes: A int32 tensor of shape [num_boxes] contains the
class id. Note this is 1 based, 0 is reserved for background class.
groundtruth_content_type: A int32 tensor of shape [num_boxes] contains
the content type. Values correspond to PageLayoutEntity::ContentType.
groundtruth_weight: A int32 tensor of shape [num_boxes], either 0 or 1.
If a region has weight 0, it will be ignored when computing the
losses.
groundtruth_boxes: A float tensor of shape [num_boxes, 5] contains the
groundtruth rotated rectangles. Each row is in [left, top, box_width,
box_height, angle] order, absolute coordinates are used.
groundtruth_aligned_boxes: A float tensor of shape [num_boxes, 4]
contains the groundtruth axis-aligned rectangles. Each row is in
[ymin, xmin, ymax, xmax] order. Currently, this is used to store
groundtruth symbol boxes.
groundtruth_vertices: A string tensor list contains encoded normalized
box or polygon coordinates. E.g. `x1,y1,x2,y2,x3,y3,x4,y4`.
groundtruth_instance_masks: A float tensor of shape [num_boxes, height,
width] contains binarized image sized instance segmentation masks.
`1.0` for positive region, `0.0` otherwise. None if not in tfe.
frame_id: A int32 tensor of shape [num_boxes], either `0` or `1`.
`0` means object comes from first image, `1` means second.
track_id: A int32 tensor of shape [num_boxes], where value indicates
identity across frame indices.
additional_channels: A uint8 tensor of shape [H, W, C] representing some
features.
"""
parsed_tensors
=
tf
.
io
.
parse_single_example
(
serialized
=
tf_example_string_tensor
,
features
=
self
.
keys_to_features
)
for
k
in
parsed_tensors
:
if
isinstance
(
parsed_tensors
[
k
],
tf
.
SparseTensor
):
if
parsed_tensors
[
k
].
dtype
==
tf
.
string
:
parsed_tensors
[
k
]
=
tf
.
sparse
.
to_dense
(
parsed_tensors
[
k
],
default_value
=
''
)
else
:
parsed_tensors
[
k
]
=
tf
.
sparse
.
to_dense
(
parsed_tensors
[
k
],
default_value
=
0
)
decoded_tensors
=
{}
decoded_tensors
.
update
(
self
.
_decode_image
(
parsed_tensors
))
decoded_tensors
.
update
(
self
.
_decode_rboxes
(
parsed_tensors
))
decoded_tensors
.
update
(
self
.
_decode_boxes
(
parsed_tensors
))
if
self
.
_use_instance_mask
:
decoded_tensors
[
'groundtruth_instance_masks'
]
=
self
.
_decode_png_instance_masks
(
parsed_tensors
)
if
self
.
_num_additional_channels
:
decoded_tensors
.
update
(
self
.
_decode_additional_channels
(
parsed_tensors
,
self
.
_num_additional_channels
))
# other attributes:
for
key
in
self
.
name_to_key
:
if
key
not
in
decoded_tensors
:
decoded_tensors
[
key
]
=
parsed_tensors
[
self
.
name_to_key
[
key
]]
if
'groundtruth_instance_masks'
not
in
decoded_tensors
:
decoded_tensors
[
'groundtruth_instance_masks'
]
=
None
return
decoded_tensors
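
A usage sketch (not part of the commit), pairing this decoder with the TFRecords written by `convert.py`; the shard path is a placeholder:

```python
import tensorflow as tf
from official.projects.unified_detector.data_loaders import tf_example_decoder

example_decoder = tf_example_decoder.TfExampleDecoder(num_additional_channels=3)
dataset = tf.data.TFRecordDataset(
    '/path/to/tfrecords/file-prefix-00000-00100.tfrecords')
for serialized in dataset.take(1):
  tensors = example_decoder.decode(serialized)
  print(tensors['image'].shape)                # (H, W, 3) uint8 image
  print(tensors['additional_channels'].shape)  # (H, W, 3) entity-id mask
  print(tensors['groundtruth_classes'][:10])   # 1=word, 2=line, 3=paragraph
```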
official/projects/unified_detector/data_loaders/universal_detection_parser.py (new file, mode 100644; diff collapsed, contents not shown)

official/projects/unified_detector/docs/images/task.png (new binary file, mode 100644; 522 KB)

official/projects/unified_detector/external_configurables.py (new file, mode 100644)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrap external code in gin."""
import
gin
import
gin.tf.external_configurables
import
tensorflow
as
tf
# Tensorflow.
gin
.
external_configurable
(
tf
.
keras
.
layers
.
experimental
.
SyncBatchNormalization
)
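
A minimal sketch (not part of the commit) of what this registration enables: once `SyncBatchNormalization` is known to Gin, its constructor arguments can be bound from the model Gin file above (e.g. `SyncBatchNormalization.momentum = 0.95`):

```python
import gin
import tensorflow as tf

# Registering returns a Gin-wrapped callable whose arguments can be bound.
SyncBN = gin.external_configurable(
    tf.keras.layers.experimental.SyncBatchNormalization)
gin.parse_config('SyncBatchNormalization.momentum = 0.95')

layer = SyncBN()        # constructed through the Gin wrapper
print(layer.momentum)   # 0.95
```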
official/projects/unified_detector/modeling/universal_detector.py (new file, mode 100644; diff collapsed, contents not shown)

official/projects/unified_detector/registry_imports.py (new file, mode 100644)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""All necessary imports for registration."""
# pylint: disable=unused-import
from
official.projects.unified_detector
import
external_configurables
from
official.projects.unified_detector.configs
import
ocr_config
from
official.projects.unified_detector.tasks
import
ocr_task
from
official.vision
import
registry_imports
official/projects/unified_detector/requirements.txt (new file, mode 100644)

tf-nightly
gin-config
opencv-python==4.1.2.30
absl-py>=1.0.0
shapely>=1.8.1
apache_beam>=2.37.0
matplotlib>=3.5.1
notebook>=6.4.10
official/projects/unified_detector/run_inference.py (new file, mode 100644)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""A binary to run unified detector."""

import json
import os
from typing import Any, Dict, Sequence, Union

from absl import app
from absl import flags
from absl import logging
import cv2
import gin
import numpy as np
import tensorflow as tf
import tqdm

from official.projects.unified_detector import external_configurables  # pylint: disable=unused-import
from official.projects.unified_detector.modeling import universal_detector
from official.projects.unified_detector.utils import utilities

# group two lines into a paragraph if affinity score higher than this
_PARA_GROUP_THR = 0.5

# MODEL spec
_GIN_FILE = flags.DEFINE_string(
    'gin_file', None, 'Path to the Gin file that defines the model.')
_CKPT_PATH = flags.DEFINE_string(
    'ckpt_path', None, 'Path to the checkpoint directory.')
_IMG_SIZE = flags.DEFINE_integer(
    'img_size', 1024, 'Size of the image fed to the model.')

# Input & Output
# Note that, all images specified by `img_file` and `img_dir` will be processed.
_IMG_FILE = flags.DEFINE_multi_string('img_file', [], 'Paths to the images.')
_IMG_DIR = flags.DEFINE_multi_string(
    'img_dir', [], 'Paths to the image directories.')
_OUTPUT_PATH = flags.DEFINE_string('output_path', None, 'Path for the output.')
_VIS_DIR = flags.DEFINE_string(
    'vis_dir', None, 'Path for the visualization output.')


def _preprocess(raw_image: np.ndarray) -> Union[np.ndarray, float]:
  """Convert a raw image to properly resized, padded, and normalized ndarray."""
  # (1) convert to tf.Tensor and float32.
  img_tensor = tf.convert_to_tensor(raw_image, dtype=tf.float32)
  # (2) pad to square.
  height, width = img_tensor.shape[:2]
  maximum_side = tf.maximum(height, width)
  height_pad = maximum_side - height
  width_pad = maximum_side - width
  img_tensor = tf.pad(
      img_tensor, [[0, height_pad], [0, width_pad], [0, 0]],
      constant_values=127)
  ratio = maximum_side / _IMG_SIZE.value
  # (3) resize long side to the maximum length.
  img_tensor = tf.image.resize(img_tensor, (_IMG_SIZE.value, _IMG_SIZE.value))
  img_tensor = tf.cast(img_tensor, tf.uint8)
  # (4) normalize
  img_tensor = utilities.normalize_image_to_range(img_tensor)
  # (5) Add batch dimension and return as numpy array.
  return tf.expand_dims(img_tensor, 0).numpy(), float(ratio)


def load_model() -> tf.keras.layers.Layer:
  gin.parse_config_file(_GIN_FILE.value)
  model = universal_detector.UniversalDetector()
  ckpt = tf.train.Checkpoint(model=model)
  ckpt_path = _CKPT_PATH.value
  logging.info('Load ckpt from: %s', ckpt_path)
  ckpt.restore(ckpt_path).expect_partial()
  return model


def inference(img_file: str, model: tf.keras.layers.Layer) -> Dict[str, Any]:
  """Inference step."""
  img = cv2.cvtColor(cv2.imread(img_file), cv2.COLOR_BGR2RGB)
  img_ndarray, ratio = _preprocess(img)

  output_dict = model.serve(img_ndarray)
  class_tensor = output_dict['classes'].numpy()
  mask_tensor = output_dict['masks'].numpy()
  group_tensor = output_dict['groups'].numpy()

  indices = np.where(class_tensor[0])[0].tolist()  # indices of positive slots.
  mask_list = [
      mask_tensor[0, :, :, index] for index in indices
  ]  # List of mask ndarray.

  # Form lines and words
  lines = []
  line_indices = []
  for index, mask in tqdm.tqdm(zip(indices, mask_list)):
    line = {
        'words': [],
        'text': '',
    }

    contours, _ = cv2.findContours((mask > 0.).astype(np.uint8), cv2.RETR_TREE,
                                   cv2.CHAIN_APPROX_SIMPLE)[-2:]
    for contour in contours:
      if (isinstance(contour, np.ndarray) and len(contour.shape) == 3 and
          contour.shape[0] > 2 and contour.shape[1] == 1 and
          contour.shape[2] == 2):
        cnt_list = (contour[:, 0] * ratio).astype(np.int32).tolist()
        line['words'].append({'text': '', 'vertices': cnt_list})
      else:
        logging.error('Invalid contour: %s, discarded', str(contour))
    if line['words']:
      lines.append(line)
      line_indices.append(index)

  # Form paragraphs
  line_grouping = utilities.DisjointSet(len(line_indices))
  affinity = group_tensor[0][line_indices][:, line_indices]
  for i1, i2 in zip(*np.where(affinity > _PARA_GROUP_THR)):
    line_grouping.union(i1, i2)

  line_groups = line_grouping.to_group()
  paragraphs = []
  for line_group in line_groups:
    paragraph = {'lines': []}
    for id_ in line_group:
      paragraph['lines'].append(lines[id_])
    if paragraph:
      paragraphs.append(paragraph)

  return paragraphs


def main(argv: Sequence[str]) -> None:
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Get list of images
  img_lists = []
  img_lists.extend(_IMG_FILE.value)
  for img_dir in _IMG_DIR.value:
    img_lists.extend(tf.io.gfile.glob(os.path.join(img_dir, '*')))

  logging.info('Total number of input images: %d', len(img_lists))

  model = load_model()
  vis_dis = _VIS_DIR.value

  output = {'annotations': []}
  for img_file in tqdm.tqdm(img_lists):
    output['annotations'].append({
        'image_id': img_file.split('/')[-1].split('.')[0],
        'paragraphs': inference(img_file, model),
    })

    if vis_dis:
      key = output['annotations'][-1]['image_id']
      paragraphs = output['annotations'][-1]['paragraphs']
      img = cv2.cvtColor(cv2.imread(img_file), cv2.COLOR_BGR2RGB)
      word_bnds = []
      line_bnds = []
      para_bnds = []
      for paragraph in paragraphs:
        paragraph_points_list = []
        for line in paragraph['lines']:
          line_points_list = []
          for word in line['words']:
            word_bnds.append(
                np.array(word['vertices'], np.int32).reshape((-1, 1, 2)))
            line_points_list.extend(word['vertices'])
          paragraph_points_list.extend(line_points_list)

          line_points = np.array(line_points_list, np.int32)  # (N,2)
          left = int(np.min(line_points[:, 0]))
          top = int(np.min(line_points[:, 1]))
          right = int(np.max(line_points[:, 0]))
          bottom = int(np.max(line_points[:, 1]))
          line_bnds.append(
              np.array([[[left, top]], [[right, top]], [[right, bottom]],
                        [[left, bottom]]], np.int32))
        para_points = np.array(paragraph_points_list, np.int32)  # (N,2)
        left = int(np.min(para_points[:, 0]))
        top = int(np.min(para_points[:, 1]))
        right = int(np.max(para_points[:, 0]))
        bottom = int(np.max(para_points[:, 1]))
        para_bnds.append(
            np.array([[[left, top]], [[right, top]], [[right, bottom]],
                      [[left, bottom]]], np.int32))

      for name, bnds in zip(['paragraph', 'line', 'word'],
                            [para_bnds, line_bnds, word_bnds]):
        vis = cv2.polylines(img, bnds, True, (0, 0, 255), 2)
        cv2.imwrite(
            os.path.join(vis_dis, f'{key}-{name}.jpg'),
            cv2.cvtColor(vis, cv2.COLOR_RGB2BGR))

  with tf.io.gfile.GFile(_OUTPUT_PATH.value, mode='w') as f:
    f.write(json.dumps(output, ensure_ascii=False, indent=2))


if __name__ == '__main__':
  flags.mark_flags_as_required(['gin_file', 'ckpt_path', 'output_path'])
  app.run(main)
official/projects/unified_detector/tasks/all_models.py (new file, mode 100644)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Import all models.
All model files are imported here so that they can be referenced in Gin. Also,
importing here avoids making ocr_task.py too messy.
"""
# pylint: disable=unused-import
from
official.projects.unified_detector.modeling
import
universal_detector
official/projects/unified_detector/tasks/ocr_task.py (new file, mode 100644)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task definition for ocr."""
from typing import Callable, Dict, Optional, Sequence, Tuple, Union

import gin
import tensorflow as tf

from official.core import base_task
from official.core import config_definitions as cfg
from official.core import task_factory
from official.projects.unified_detector.configs import ocr_config
from official.projects.unified_detector.data_loaders import input_reader
from official.projects.unified_detector.tasks import all_models  # pylint: disable=unused-import
from official.projects.unified_detector.utils import typing

NestedTensorDict = typing.NestedTensorDict
ModelType = Union[tf.keras.layers.Layer, tf.keras.Model]


@task_factory.register_task_cls(ocr_config.OcrTaskConfig)
@gin.configurable
class OcrTask(base_task.Task):
  """Defining the OCR training task."""

  _loss_items = []

  def __init__(self,
               params: cfg.TaskConfig,
               logging_dir: Optional[str] = None,
               name: Optional[str] = None,
               model_fn: Callable[..., ModelType] = gin.REQUIRED):
    super().__init__(params, logging_dir, name)
    self._model_fn = model_fn

  def build_model(self) -> ModelType:
    """Build and return the model, record the loss items as well."""
    model = self._model_fn()
    self._loss_items.extend(model.loss_items)
    return model

  def build_inputs(
      self,
      params: cfg.DataConfig,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    """Build the tf.data.Dataset instance."""
    return input_reader.InputFn(is_training=params.is_training)({},
                                                                input_context)

  def build_metrics(
      self, training: bool = True) -> Sequence[tf.keras.metrics.Metric]:
    """Build the metrics (currently, only for loss summaries in TensorBoard)."""
    del training
    metrics = []
    # Add one Mean metric per loss item reported by the model.
    for name in self._loss_items:
      metrics.append(tf.keras.metrics.Mean(name, dtype=tf.float32))
    # TODO(longshangbang): add evaluation metrics
    return metrics

  def train_step(
      self,
      inputs: Tuple[NestedTensorDict, NestedTensorDict],
      model: ModelType,
      optimizer: tf.keras.optimizers.Optimizer,
      metrics: Optional[Sequence[tf.keras.metrics.Metric]] = None
  ) -> Dict[str, tf.Tensor]:
    features, labels = inputs
    input_dict = {"features": features}
    if self.task_config.model_call_needs_labels:
      input_dict["labels"] = labels

    is_mixed_precision = isinstance(
        optimizer, tf.keras.mixed_precision.LossScaleOptimizer)

    with tf.GradientTape() as tape:
      outputs = model(**input_dict, training=True)
      loss, loss_dict = model.compute_losses(labels=labels, outputs=outputs)
      loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
      if is_mixed_precision:
        loss = optimizer.get_scaled_loss(loss)

    tvars = model.trainable_variables
    grads = tape.gradient(loss, tvars)
    if is_mixed_precision:
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    logs = {"loss": loss}
    if metrics:
      for m in metrics:
        m.update_state(loss_dict[m.name])
    return logs
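From `build_model()` and `train_step()` above, the object returned by the Gin-provided `model_fn` must expose a `loss_items` attribute (metric names), a `compute_losses()` method returning the total loss plus a per-item loss dictionary, and a call signature accepting `features` (and optionally `labels`) keyword arguments. The real implementation is `universal_detector.py`; the stub below is only a minimal, hypothetical sketch of that contract:

```python
from typing import Dict, Optional, Tuple

import tensorflow as tf


class ToyDetector(tf.keras.Model):
  """Hypothetical stand-in that satisfies the interface OcrTask expects."""

  def __init__(self):
    super().__init__()
    # Names must match the keys of the loss dict returned by compute_losses().
    self.loss_items = ['toy_loss']
    self._head = tf.keras.layers.Dense(1)

  def call(self,
           features: Dict[str, tf.Tensor],
           labels: Optional[Dict[str, tf.Tensor]] = None,
           training: bool = False) -> Dict[str, tf.Tensor]:
    del labels, training
    return {'logits': self._head(features['images'])}

  def compute_losses(
      self, labels: Dict[str, tf.Tensor], outputs: Dict[str, tf.Tensor]
  ) -> Tuple[tf.Tensor, Dict[str, tf.Tensor]]:
    toy_loss = tf.reduce_mean(tf.square(outputs['logits'] - labels['targets']))
    return toy_loss, {'toy_loss': toy_loss}
```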
official/projects/unified_detector/train.py
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorFlow Model Garden Vision training driver."""
from absl import app
from absl import flags
import gin

from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
# pylint: disable=unused-import
from official.projects.unified_detector import registry_imports
# pylint: enable=unused-import

FLAGS = flags.FLAGS


def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  params = train_utils.parse_configuration(FLAGS)
  model_dir = FLAGS.model_dir
  if 'train' in FLAGS.mode:
    # Pure eval modes do not output yaml files. Otherwise a continuous eval job
    # may race against the train job for writing the same file.
    train_utils.serialize_config(params, model_dir)

  # Set the mixed-precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can significantly speed up the model by utilizing float16 on GPUs and
  # bfloat16 on TPUs. loss_scale takes effect only when the dtype is float16.
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)

  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu)
  with distribution_strategy.scope():
    task = task_factory.get_task(params.task, logging_dir=model_dir)

  train_lib.run_experiment(
      distribution_strategy=distribution_strategy,
      task=task,
      mode=FLAGS.mode,
      params=params,
      model_dir=model_dir)

  train_utils.save_gin_config(FLAGS.mode, model_dir)


if __name__ == '__main__':
  tfm_flags.define_flags()
  flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
  app.run(main)
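The mixed-precision setup above pairs with the loss scaling in `OcrTask.train_step`: under a float16 policy the Model Garden wraps the optimizer in a `LossScaleOptimizer`, which is why the task scales the loss before differentiation and unscales the gradients afterwards. A self-contained toy sketch of that pattern (the model, data, and optimizer here are made up, not the project's configuration):

```python
import tensorflow as tf

tf.keras.mixed_precision.set_global_policy('mixed_float16')
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.SGD(0.1))

x = tf.random.normal([4, 8])
y = tf.random.normal([4, 1])
with tf.GradientTape() as tape:
  # Layer outputs are float16 under this policy; cast back for a float32 loss.
  preds = tf.cast(model(x, training=True), tf.float32)
  loss = tf.reduce_mean(tf.square(preds - y))
  scaled_loss = optimizer.get_scaled_loss(loss)  # avoid float16 underflow
grads = tape.gradient(scaled_loss, model.trainable_variables)
grads = optimizer.get_unscaled_gradients(grads)  # undo the scaling
optimizer.apply_gradients(zip(grads, model.trainable_variables))
```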
official/projects/unified_detector/utils/typing.py
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Typing extension."""
from typing import Dict, Union

import numpy as np
import tensorflow as tf

NpDict = Dict[str, np.ndarray]
FeaturesAndLabelsType = Dict[str, Dict[str, tf.Tensor]]
TensorDict = Dict[Union[str, int], tf.Tensor]
NestedTensorDict = Dict[Union[str, int], Union[tf.Tensor, TensorDict]]
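As a quick, made-up illustration of these aliases (keys and shapes here are arbitrary), a batch typed as `NestedTensorDict` could look like:

```python
import tensorflow as tf

example: NestedTensorDict = {
    'features': {'images': tf.zeros([2, 640, 640, 3])},  # one nested TensorDict
    'num_instances': tf.constant([5, 3]),                # or a plain tensor
}
```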