Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
32e4ca51
Commit
32e4ca51
authored
Nov 28, 2023
by
qianyj
Browse files
Update code to v2.11.0
parents
9485aa1d
71060f67
Changes
775
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
472 additions
and
50 deletions
+472
-50
official/projects/centernet/configs/experiments/coco-centernet-hourglass-gpu.yaml
...net/configs/experiments/coco-centernet-hourglass-gpu.yaml
+4
-4
official/projects/centernet/configs/experiments/coco-centernet-hourglass-tpu.yaml
...net/configs/experiments/coco-centernet-hourglass-tpu.yaml
+6
-6
official/projects/centernet/dataloaders/__init__.py
official/projects/centernet/dataloaders/__init__.py
+14
-0
official/projects/centernet/dataloaders/centernet_input.py
official/projects/centernet/dataloaders/centernet_input.py
+8
-8
official/projects/centernet/losses/__init__.py
official/projects/centernet/losses/__init__.py
+14
-0
official/projects/centernet/losses/centernet_losses.py
official/projects/centernet/losses/centernet_losses.py
+1
-1
official/projects/centernet/losses/centernet_losses_test.py
official/projects/centernet/losses/centernet_losses_test.py
+2
-2
official/projects/centernet/modeling/__init__.py
official/projects/centernet/modeling/__init__.py
+14
-0
official/projects/centernet/modeling/backbones/__init__.py
official/projects/centernet/modeling/backbones/__init__.py
+14
-0
official/projects/centernet/modeling/backbones/hourglass.py
official/projects/centernet/modeling/backbones/hourglass.py
+5
-5
official/projects/centernet/modeling/backbones/hourglass_test.py
...l/projects/centernet/modeling/backbones/hourglass_test.py
+5
-5
official/projects/centernet/modeling/centernet_model.py
official/projects/centernet/modeling/centernet_model.py
+1
-1
official/projects/centernet/modeling/centernet_model_test.py
official/projects/centernet/modeling/centernet_model_test.py
+7
-7
official/projects/centernet/modeling/heads/__init__.py
official/projects/centernet/modeling/heads/__init__.py
+14
-0
official/projects/centernet/modeling/heads/centernet_head.py
official/projects/centernet/modeling/heads/centernet_head.py
+3
-4
official/projects/centernet/modeling/heads/centernet_head_test.py
.../projects/centernet/modeling/heads/centernet_head_test.py
+2
-2
official/projects/centernet/modeling/layers/__init__.py
official/projects/centernet/modeling/layers/__init__.py
+14
-0
official/projects/centernet/modeling/layers/cn_nn_blocks.py
official/projects/centernet/modeling/layers/cn_nn_blocks.py
+2
-2
official/projects/centernet/modeling/layers/cn_nn_blocks_test.py
...l/projects/centernet/modeling/layers/cn_nn_blocks_test.py
+3
-3
official/projects/centernet/modeling/layers/detection_generator.py
...projects/centernet/modeling/layers/detection_generator.py
+339
-0
No files found.
Too many changes to show.
To preserve performance only
775 of 775+
files are displayed.
Plain diff
Email patch
official/
vision/beta/
projects/centernet/configs/experiments/coco-centernet-hourglass-gpu.yaml
→
official/projects/centernet/configs/experiments/coco-centernet-hourglass-gpu.yaml
View file @
32e4ca51
...
...
@@ -38,11 +38,11 @@ task:
per_category_metrics
:
false
weight_decay
:
0.0005
gradient_clip_norm
:
10.0
annotation_file
:
'
coco/instances_val2017.json'
init_checkpoint
:
'
/placer/prod/scratch/home
/tf
-
model
-
garden
-dev
/vision/centernet/extremenet_hg104_512x512_coco17
/2021-10-19'
annotation_file
:
'
/readahead/200M/placer/prod/home/tensorflow-performance-data/datasets/
coco/instances_val2017.json'
init_checkpoint
:
gs:/
/tf
_
model
_
garden/vision/centernet/extremenet_hg104_512x512_coco17
init_checkpoint_modules
:
'
backbone'
train_data
:
input_path
:
'
coco/train*'
input_path
:
'
/readahead/200M/placer/prod/home/tensorflow-performance-data/datasets/
coco/train*'
drop_remainder
:
true
dtype
:
'
float16'
global_batch_size
:
64
...
...
@@ -57,7 +57,7 @@ task:
aug_rand_contrast
:
true
odapi_augmentation
:
true
validation_data
:
input_path
:
'
coco/val*'
input_path
:
'
/readahead/200M/placer/prod/home/tensorflow-performance-data/datasets/
coco/val*'
drop_remainder
:
false
dtype
:
'
float16'
global_batch_size
:
16
...
...
official/
vision/beta/
projects/centernet/configs/experiments/coco-centernet-hourglass-tpu.yaml
→
official/projects/centernet/configs/experiments/coco-centernet-hourglass-tpu.yaml
View file @
32e4ca51
...
...
@@ -37,11 +37,11 @@ task:
per_category_metrics
:
false
weight_decay
:
0.0005
gradient_clip_norm
:
10.0
annotation_file
:
'
coco/instances_val2017.json'
init_checkpoint
:
'
/placer/prod/scratch/home
/tf
-
model
-
garden
-dev
/vision/centernet/extremenet_hg104_512x512_coco17
/2021-10-19'
annotation_file
:
'
/readahead/200M/placer/prod/home/tensorflow-performance-data/datasets/
coco/instances_val2017.json'
init_checkpoint
:
gs:/
/tf
_
model
_
garden/vision/centernet/extremenet_hg104_512x512_coco17
init_checkpoint_modules
:
'
backbone'
train_data
:
input_path
:
'
coco/train*'
input_path
:
'
/readahead/200M/placer/prod/home/tensorflow-performance-data/datasets/
coco/train*'
drop_remainder
:
true
dtype
:
'
bfloat16'
global_batch_size
:
128
...
...
@@ -56,14 +56,14 @@ task:
aug_rand_contrast
:
true
odapi_augmentation
:
true
validation_data
:
input_path
:
'
coco/val*'
input_path
:
'
/readahead/200M/placer/prod/home/tensorflow-performance-data/datasets/
coco/val*'
drop_remainder
:
false
dtype
:
'
bfloat16'
global_batch_size
:
1
6
global_batch_size
:
6
4
is_training
:
false
trainer
:
train_steps
:
140000
validation_steps
:
78
# 5000 /
1
6
validation_steps
:
78
# 5000 / 6
4
steps_per_loop
:
924
# 118287 / 128
validation_interval
:
924
summary_interval
:
924
...
...
official/projects/centernet/dataloaders/__init__.py
0 → 100644
View file @
32e4ca51
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
official/
vision/beta/
projects/centernet/dataloaders/centernet_input.py
→
official/projects/centernet/dataloaders/centernet_input.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -18,13 +18,13 @@ from typing import Tuple
import
tensorflow
as
tf
from
official.
vision.beta.dataloader
s
import
parser
from
official.
vision.beta.dataloaders
import
util
s
from
official.
vision.b
et
a
.ops
import
box
_ops
from
official.vision.
beta.ops
import
preprocess_ops
from
official.vision.
beta.projects.centernet.op
s
import
box_list
from
official.vision.
beta.projects.centernet.
ops
import
box_
list_
ops
from
official.vision.
beta.projects.centernet.
ops
import
preprocess_ops
as
cn_prep_ops
from
official.
projects.centernet.op
s
import
box_list
from
official.
projects.centernet.ops
import
box_list_op
s
from
official.
projects.centern
et.ops
import
preprocess_ops
as
cn_prep
_ops
from
official.vision.
dataloaders
import
parser
from
official.vision.
dataloader
s
import
utils
from
official.vision.ops
import
box_ops
from
official.vision.ops
import
preprocess_ops
CHANNEL_MEANS
=
(
104.01362025
,
114.03422265
,
119.9165958
)
...
...
official/projects/centernet/losses/__init__.py
0 → 100644
View file @
32e4ca51
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
official/
vision/beta/
projects/centernet/losses/centernet_losses.py
→
official/projects/centernet/losses/centernet_losses.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/
vision/beta/
projects/centernet/losses/centernet_losses_test.py
→
official/projects/centernet/losses/centernet_losses_test.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -17,7 +17,7 @@
import
numpy
as
np
import
tensorflow
as
tf
from
official.
vision.beta.
projects.centernet.losses
import
centernet_losses
from
official.projects.centernet.losses
import
centernet_losses
LOG_2
=
np
.
log
(
2
)
LOG_3
=
np
.
log
(
3
)
...
...
official/projects/centernet/modeling/__init__.py
0 → 100644
View file @
32e4ca51
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
official/projects/centernet/modeling/backbones/__init__.py
0 → 100644
View file @
32e4ca51
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
official/
vision/beta/
projects/centernet/modeling/backbones/hourglass.py
→
official/projects/centernet/modeling/backbones/hourglass.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -19,10 +19,10 @@ from typing import Optional
import
tensorflow
as
tf
from
official.modeling
import
hyperparams
from
official.
vision.b
et
a
.modeling.
backbone
s
import
factory
from
official.vision.
beta.
modeling.backbones
import
mobilenet
from
official.vision.
beta.
modeling.
layer
s
import
nn_blocks
from
official.vision.
beta.projects.centernet.
modeling.layers
import
cn_
nn_blocks
from
official.
projects.centern
et.modeling.
layer
s
import
cn_nn_blocks
from
official.vision.modeling.backbones
import
factory
from
official.vision.modeling.
backbone
s
import
mobilenet
from
official.vision.modeling.layers
import
nn_blocks
HOURGLASS_SPECS
=
{
10
:
{
...
...
official/
vision/beta/
projects/centernet/modeling/backbones/hourglass_test.py
→
official/projects/centernet/modeling/backbones/hourglass_test.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -18,10 +18,10 @@ from absl.testing import parameterized
import
numpy
as
np
import
tensorflow
as
tf
from
official.
vision.beta.configs
import
common
from
official.
vision.beta.
projects.centernet.co
mmon
import
registry_imports
# pylint: disable=unused-import
from
official.
vision.beta.
projects.centernet.
config
s
import
backbone
s
from
official.vision.
beta.projects.centernet.modeling.backbones
import
hourglass
from
official.
projects.centernet.common
import
registry_imports
# pylint: disable=unused-import
from
official.projects.centernet.co
nfigs
import
backbones
from
official.projects.centernet.
modeling.backbone
s
import
hourglas
s
from
official.vision.
configs
import
common
class
HourglassTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
...
...
official/
vision/beta/
projects/centernet/modeling/centernet_model.py
→
official/projects/centernet/modeling/centernet_model.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/
vision/beta/
projects/centernet/modeling/centernet_model_test.py
→
official/projects/centernet/modeling/centernet_model_test.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -17,12 +17,12 @@
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
official.
vision.beta
.configs
import
common
from
official.
vision.beta.
projects.centernet.
configs
import
backbones
from
official.
vision.beta.
projects.centernet.modeling
import
centernet_model
from
official.
vision.beta.
projects.centernet.modeling.
backbone
s
import
hourglass
from
official.
vision.beta.
projects.centernet.modeling.
head
s
import
centernet_head
from
official.vision.
beta.projects.centernet.modeling.layers
import
detection_generator
from
official.
projects.centernet
.configs
import
backbones
from
official.projects.centernet.
modeling
import
centernet_model
from
official.projects.centernet.modeling
.backbones
import
hourglass
from
official.projects.centernet.modeling.
head
s
import
centernet_head
from
official.projects.centernet.modeling.
layer
s
import
detection_generator
from
official.vision.
configs
import
common
class
CenterNetTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
...
...
official/projects/centernet/modeling/heads/__init__.py
0 → 100644
View file @
32e4ca51
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
official/
vision/beta/
projects/centernet/modeling/heads/centernet_head.py
→
official/projects/centernet/modeling/heads/centernet_head.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -14,11 +14,11 @@
"""Contains the definitions of head for CenterNet."""
from
typing
import
Any
,
Mapping
,
Dict
,
List
from
typing
import
Any
,
Dict
,
List
,
Mapping
import
tensorflow
as
tf
from
official.
vision.beta.
projects.centernet.modeling.layers
import
cn_nn_blocks
from
official.projects.centernet.modeling.layers
import
cn_nn_blocks
class
CenterNetHead
(
tf
.
keras
.
Model
):
...
...
@@ -61,7 +61,6 @@ class CenterNetHead(tf.keras.Model):
self
.
_heatmap_bias
=
heatmap_bias
self
.
_num_inputs
=
len
(
input_levels
)
input_levels
=
sorted
(
self
.
_input_specs
.
keys
())
inputs
=
{
level
:
tf
.
keras
.
layers
.
Input
(
shape
=
self
.
_input_specs
[
level
][
1
:])
for
level
in
input_levels
}
outputs
=
{}
...
...
official/
vision/beta/
projects/centernet/modeling/heads/centernet_head_test.py
→
official/projects/centernet/modeling/heads/centernet_head_test.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -18,7 +18,7 @@ from absl.testing import parameterized
import
numpy
as
np
import
tensorflow
as
tf
from
official.
vision.beta.
projects.centernet.modeling.heads
import
centernet_head
from
official.projects.centernet.modeling.heads
import
centernet_head
class
CenterNetHeadTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
...
...
official/projects/centernet/modeling/layers/__init__.py
0 → 100644
View file @
32e4ca51
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
official/
vision/beta/
projects/centernet/modeling/layers/cn_nn_blocks.py
→
official/projects/centernet/modeling/layers/cn_nn_blocks.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -18,7 +18,7 @@ from typing import List, Optional
import
tensorflow
as
tf
from
official.vision.
beta.
modeling.layers
import
nn_blocks
from
official.vision.modeling.layers
import
nn_blocks
def
_apply_blocks
(
inputs
,
blocks
):
...
...
official/
vision/beta/
projects/centernet/modeling/layers/cn_nn_blocks_test.py
→
official/projects/centernet/modeling/layers/cn_nn_blocks_test.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -21,8 +21,8 @@ from absl.testing import parameterized
import
numpy
as
np
import
tensorflow
as
tf
from
official.
vision.b
et
a
.modeling.layers
import
nn_blocks
from
official.vision.
beta.projects.centernet.
modeling.layers
import
cn_
nn_blocks
from
official.
projects.centern
et.modeling.layers
import
cn_
nn_blocks
from
official.vision.modeling.layers
import
nn_blocks
class
HourglassBlockPyTorch
(
tf
.
keras
.
layers
.
Layer
):
...
...
official/projects/centernet/modeling/layers/detection_generator.py
0 → 100644
View file @
32e4ca51
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Detection generator for centernet.
Parses predictions from the CenterNet head into the final bounding boxes,
confidences, and classes. This class contains repurposed methods from the
TensorFlow Object Detection API
in: https://github.com/tensorflow/models/blob/master/research/object_detection
/meta_architectures/center_net_meta_arch.py
"""
from
typing
import
Any
,
Mapping
import
tensorflow
as
tf
from
official.projects.centernet.ops
import
loss_ops
from
official.projects.centernet.ops
import
nms_ops
from
official.vision.ops
import
box_ops
class
CenterNetDetectionGenerator
(
tf
.
keras
.
layers
.
Layer
):
"""CenterNet Detection Generator."""
def
__init__
(
self
,
input_image_dims
:
int
=
512
,
net_down_scale
:
int
=
4
,
max_detections
:
int
=
100
,
peak_error
:
float
=
1e-6
,
peak_extract_kernel_size
:
int
=
3
,
class_offset
:
int
=
1
,
use_nms
:
bool
=
False
,
nms_pre_thresh
:
float
=
0.1
,
nms_thresh
:
float
=
0.4
,
**
kwargs
):
"""Initialize CenterNet Detection Generator.
Args:
input_image_dims: An `int` that specifies the input image size.
net_down_scale: An `int` that specifies stride of the output.
max_detections: An `int` specifying the maximum number of bounding
boxes generated. This is an upper bound, so the number of generated
boxes may be less than this due to thresholding/non-maximum suppression.
peak_error: A `float` for determining non-valid heatmap locations to mask.
peak_extract_kernel_size: An `int` indicating the kernel size used when
performing max-pool over the heatmaps to detect valid center locations
from its neighbors. From the paper, set this to 3 to detect valid.
locations that have responses greater than its 8-connected neighbors
class_offset: An `int` indicating to add an offset to the class
prediction if the dataset labels have been shifted.
use_nms: A `bool` for whether or not to use non-maximum suppression to
filter the bounding boxes.
nms_pre_thresh: A `float` for pre-nms threshold.
nms_thresh: A `float` for nms threshold.
**kwargs: Additional keyword arguments to be passed.
"""
super
(
CenterNetDetectionGenerator
,
self
).
__init__
(
**
kwargs
)
# Object center selection parameters
self
.
_max_detections
=
max_detections
self
.
_peak_error
=
peak_error
self
.
_peak_extract_kernel_size
=
peak_extract_kernel_size
# Used for adjusting class prediction
self
.
_class_offset
=
class_offset
# Box normalization parameters
self
.
_net_down_scale
=
net_down_scale
self
.
_input_image_dims
=
input_image_dims
self
.
_use_nms
=
use_nms
self
.
_nms_pre_thresh
=
nms_pre_thresh
self
.
_nms_thresh
=
nms_thresh
def
process_heatmap
(
self
,
feature_map
:
tf
.
Tensor
,
kernel_size
:
int
)
->
tf
.
Tensor
:
"""Processes the heatmap into peaks for box selection.
Given a heatmap, this function first masks out nearby heatmap locations of
the same class using max-pooling such that, ideally, only one center for the
object remains. Then, center locations are masked according to their scores
in comparison to a threshold. NOTE: Repurposed from Google OD API.
Args:
feature_map: A Tensor with shape [batch_size, height, width, num_classes]
which is the center heatmap predictions.
kernel_size: An integer value for max-pool kernel size.
Returns:
A Tensor with the same shape as the input but with non-valid center
prediction locations masked out.
"""
feature_map
=
tf
.
math
.
sigmoid
(
feature_map
)
if
not
kernel_size
or
kernel_size
==
1
:
feature_map_peaks
=
feature_map
else
:
feature_map_max_pool
=
tf
.
nn
.
max_pool
(
feature_map
,
ksize
=
kernel_size
,
strides
=
1
,
padding
=
'SAME'
)
feature_map_peak_mask
=
tf
.
math
.
abs
(
feature_map
-
feature_map_max_pool
)
<
self
.
_peak_error
# Zero out everything that is not a peak.
feature_map_peaks
=
(
feature_map
*
tf
.
cast
(
feature_map_peak_mask
,
feature_map
.
dtype
))
return
feature_map_peaks
def
get_top_k_peaks
(
self
,
feature_map_peaks
:
tf
.
Tensor
,
batch_size
:
int
,
width
:
int
,
num_classes
:
int
,
k
:
int
=
100
):
"""Gets the scores and indices of the top-k peaks from the feature map.
This function flattens the feature map in order to retrieve the top-k
peaks, then computes the x, y, and class indices for those scores.
NOTE: Repurposed from Google OD API.
Args:
feature_map_peaks: A `Tensor` with shape [batch_size, height,
width, num_classes] which is the processed center heatmap peaks.
batch_size: An `int` that indicates the batch size of the input.
width: An `int` that indicates the width (and also height) of the input.
num_classes: An `int` for the number of possible classes. This is also
the channel depth of the input.
k: `int`` that controls how many peaks to select.
Returns:
top_scores: A Tensor with shape [batch_size, k] containing the top-k
scores.
y_indices: A Tensor with shape [batch_size, k] containing the top-k
y-indices corresponding to top_scores.
x_indices: A Tensor with shape [batch_size, k] containing the top-k
x-indices corresponding to top_scores.
channel_indices: A Tensor with shape [batch_size, k] containing the top-k
channel indices corresponding to top_scores.
"""
# Flatten the entire prediction per batch
feature_map_peaks_flat
=
tf
.
reshape
(
feature_map_peaks
,
[
batch_size
,
-
1
])
# top_scores and top_indices have shape [batch_size, k]
top_scores
,
top_indices
=
tf
.
math
.
top_k
(
feature_map_peaks_flat
,
k
=
k
)
# Get x, y and channel indices corresponding to the top indices in the flat
# array.
y_indices
,
x_indices
,
channel_indices
=
(
loss_ops
.
get_row_col_channel_indices_from_flattened_indices
(
top_indices
,
width
,
num_classes
))
return
top_scores
,
y_indices
,
x_indices
,
channel_indices
def
get_boxes
(
self
,
y_indices
:
tf
.
Tensor
,
x_indices
:
tf
.
Tensor
,
channel_indices
:
tf
.
Tensor
,
height_width_predictions
:
tf
.
Tensor
,
offset_predictions
:
tf
.
Tensor
,
num_boxes
:
int
):
"""Organizes prediction information into the final bounding boxes.
NOTE: Repurposed from Google OD API.
Args:
y_indices: A Tensor with shape [batch_size, k] containing the top-k
y-indices corresponding to top_scores.
x_indices: A Tensor with shape [batch_size, k] containing the top-k
x-indices corresponding to top_scores.
channel_indices: A Tensor with shape [batch_size, k] containing the top-k
channel indices corresponding to top_scores.
height_width_predictions: A Tensor with shape [batch_size, height,
width, 2] containing the object size predictions.
offset_predictions: A Tensor with shape [batch_size, height, width, 2]
containing the object local offset predictions.
num_boxes: `int`, the number of boxes.
Returns:
boxes: A Tensor with shape [batch_size, num_boxes, 4] that contains the
bounding box coordinates in [y_min, x_min, y_max, x_max] format.
detection_classes: A Tensor with shape [batch_size, num_boxes] that
gives the class prediction for each box.
num_detections: Number of non-zero confidence detections made.
"""
# TF Lite does not support tf.gather with batch_dims > 0, so we need to use
# tf_gather_nd instead and here we prepare the indices for that.
# shapes of heatmap output
shape
=
tf
.
shape
(
height_width_predictions
)
batch_size
,
height
,
width
=
shape
[
0
],
shape
[
1
],
shape
[
2
]
# combined indices dtype=int32
combined_indices
=
tf
.
stack
([
loss_ops
.
multi_range
(
batch_size
,
value_repetitions
=
num_boxes
),
tf
.
reshape
(
y_indices
,
[
-
1
]),
tf
.
reshape
(
x_indices
,
[
-
1
])
],
axis
=
1
)
new_height_width
=
tf
.
gather_nd
(
height_width_predictions
,
combined_indices
)
new_height_width
=
tf
.
reshape
(
new_height_width
,
[
batch_size
,
num_boxes
,
2
])
height_width
=
tf
.
maximum
(
new_height_width
,
0.0
)
# height and widths dtype=float32
heights
=
height_width
[...,
0
]
widths
=
height_width
[...,
1
]
# Get the offsets of center points
new_offsets
=
tf
.
gather_nd
(
offset_predictions
,
combined_indices
)
offsets
=
tf
.
reshape
(
new_offsets
,
[
batch_size
,
num_boxes
,
2
])
# offsets are dtype=float32
y_offsets
=
offsets
[...,
0
]
x_offsets
=
offsets
[...,
1
]
y_indices
=
tf
.
cast
(
y_indices
,
dtype
=
heights
.
dtype
)
x_indices
=
tf
.
cast
(
x_indices
,
dtype
=
widths
.
dtype
)
detection_classes
=
channel_indices
+
self
.
_class_offset
ymin
=
y_indices
+
y_offsets
-
heights
/
2.0
xmin
=
x_indices
+
x_offsets
-
widths
/
2.0
ymax
=
y_indices
+
y_offsets
+
heights
/
2.0
xmax
=
x_indices
+
x_offsets
+
widths
/
2.0
ymin
=
tf
.
clip_by_value
(
ymin
,
0.
,
tf
.
cast
(
height
,
ymin
.
dtype
))
xmin
=
tf
.
clip_by_value
(
xmin
,
0.
,
tf
.
cast
(
width
,
xmin
.
dtype
))
ymax
=
tf
.
clip_by_value
(
ymax
,
0.
,
tf
.
cast
(
height
,
ymax
.
dtype
))
xmax
=
tf
.
clip_by_value
(
xmax
,
0.
,
tf
.
cast
(
width
,
xmax
.
dtype
))
boxes
=
tf
.
stack
([
ymin
,
xmin
,
ymax
,
xmax
],
axis
=
2
)
return
boxes
,
detection_classes
def
convert_strided_predictions_to_normalized_boxes
(
self
,
boxes
:
tf
.
Tensor
):
boxes
=
boxes
*
tf
.
cast
(
self
.
_net_down_scale
,
boxes
.
dtype
)
boxes
=
boxes
/
tf
.
cast
(
self
.
_input_image_dims
,
boxes
.
dtype
)
boxes
=
tf
.
clip_by_value
(
boxes
,
0.0
,
1.0
)
return
boxes
def
__call__
(
self
,
inputs
):
# Get heatmaps from decoded outputs via final hourglass stack output
all_ct_heatmaps
=
inputs
[
'ct_heatmaps'
]
all_ct_sizes
=
inputs
[
'ct_size'
]
all_ct_offsets
=
inputs
[
'ct_offset'
]
ct_heatmaps
=
all_ct_heatmaps
[
-
1
]
ct_sizes
=
all_ct_sizes
[
-
1
]
ct_offsets
=
all_ct_offsets
[
-
1
]
shape
=
tf
.
shape
(
ct_heatmaps
)
_
,
width
=
shape
[
1
],
shape
[
2
]
batch_size
,
num_channels
=
shape
[
0
],
shape
[
3
]
# Process heatmaps using 3x3 max pool and applying sigmoid
peaks
=
self
.
process_heatmap
(
feature_map
=
ct_heatmaps
,
kernel_size
=
self
.
_peak_extract_kernel_size
)
# Get top scores along with their x, y, and class
# Each has size [batch_size, k]
scores
,
y_indices
,
x_indices
,
channel_indices
=
self
.
get_top_k_peaks
(
feature_map_peaks
=
peaks
,
batch_size
=
batch_size
,
width
=
width
,
num_classes
=
num_channels
,
k
=
self
.
_max_detections
)
# Parse the score and indices into bounding boxes
boxes
,
classes
=
self
.
get_boxes
(
y_indices
=
y_indices
,
x_indices
=
x_indices
,
channel_indices
=
channel_indices
,
height_width_predictions
=
ct_sizes
,
offset_predictions
=
ct_offsets
,
num_boxes
=
self
.
_max_detections
)
# Normalize bounding boxes
boxes
=
self
.
convert_strided_predictions_to_normalized_boxes
(
boxes
)
# Apply nms
if
self
.
_use_nms
:
boxes
=
tf
.
expand_dims
(
boxes
,
axis
=-
2
)
multi_class_scores
=
tf
.
gather_nd
(
peaks
,
tf
.
stack
([
y_indices
,
x_indices
],
-
1
),
batch_dims
=
1
)
boxes
,
_
,
scores
=
nms_ops
.
nms
(
boxes
=
boxes
,
classes
=
multi_class_scores
,
confidence
=
scores
,
k
=
self
.
_max_detections
,
limit_pre_thresh
=
True
,
pre_nms_thresh
=
0.1
,
nms_thresh
=
0.4
)
num_det
=
tf
.
reduce_sum
(
tf
.
cast
(
scores
>
0
,
dtype
=
tf
.
int32
),
axis
=
1
)
boxes
=
box_ops
.
denormalize_boxes
(
boxes
,
[
self
.
_input_image_dims
,
self
.
_input_image_dims
])
return
{
'boxes'
:
boxes
,
'classes'
:
classes
,
'confidence'
:
scores
,
'num_detections'
:
num_det
}
def
get_config
(
self
)
->
Mapping
[
str
,
Any
]:
config
=
{
'max_detections'
:
self
.
_max_detections
,
'peak_error'
:
self
.
_peak_error
,
'peak_extract_kernel_size'
:
self
.
_peak_extract_kernel_size
,
'class_offset'
:
self
.
_class_offset
,
'net_down_scale'
:
self
.
_net_down_scale
,
'input_image_dims'
:
self
.
_input_image_dims
,
'use_nms'
:
self
.
_use_nms
,
'nms_pre_thresh'
:
self
.
_nms_pre_thresh
,
'nms_thresh'
:
self
.
_nms_thresh
}
base_config
=
super
(
CenterNetDetectionGenerator
,
self
).
get_config
()
return
dict
(
list
(
base_config
.
items
())
+
list
(
config
.
items
()))
@
classmethod
def
from_config
(
cls
,
config
):
return
cls
(
**
config
)
Prev
1
…
28
29
30
31
32
33
34
35
36
…
39
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment