ModelZoo / ResNet50_tensorflow / Commits

Commit 460890ed, authored Nov 01, 2021 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 406888835
parent f2bc366e
Showing 19 changed files with 4175 additions and 0 deletions.
official/vision/beta/projects/centernet/modeling/layers/cn_nn_blocks.py  +329 -0
official/vision/beta/projects/centernet/modeling/layers/cn_nn_blocks_test.py  +154 -0
official/vision/beta/projects/centernet/modeling/layers/detection_generator.py  +340 -0
official/vision/beta/projects/centernet/ops/__init__.py  +14 -0
official/vision/beta/projects/centernet/ops/box_list.py  +215 -0
official/vision/beta/projects/centernet/ops/box_list_ops.py  +350 -0
official/vision/beta/projects/centernet/ops/loss_ops.py  +194 -0
official/vision/beta/projects/centernet/ops/nms_ops.py  +121 -0
official/vision/beta/projects/centernet/ops/preprocess_ops.py  +496 -0
official/vision/beta/projects/centernet/ops/target_assigner.py  +417 -0
official/vision/beta/projects/centernet/ops/target_assigner_test.py  +208 -0
official/vision/beta/projects/centernet/tasks/centernet.py  +425 -0
official/vision/beta/projects/centernet/train.py  +67 -0
official/vision/beta/projects/centernet/utils/checkpoints/__init__.py  +14 -0
official/vision/beta/projects/centernet/utils/checkpoints/config_classes.py  +297 -0
official/vision/beta/projects/centernet/utils/checkpoints/config_data.py  +111 -0
official/vision/beta/projects/centernet/utils/checkpoints/load_weights.py  +173 -0
official/vision/beta/projects/centernet/utils/checkpoints/read_checkpoints.py  +114 -0
official/vision/beta/projects/centernet/utils/tf2_centernet_checkpoint_converter.py  +136 -0
official/vision/beta/projects/centernet/modeling/layers/cn_nn_blocks.py  0 → 100644
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains common building blocks for CenterNet neural networks."""

from typing import List, Optional

import tensorflow as tf

from official.vision.beta.modeling.layers import nn_blocks


def _apply_blocks(inputs, blocks):
  """Applies a sequence of blocks to the inputs."""
  net = inputs

  for block in blocks:
    net = block(net)

  return net


def _make_repeated_residual_blocks(
    reps: int,
    out_channels: int,
    use_sync_bn: bool = True,
    norm_momentum: float = 0.1,
    norm_epsilon: float = 1e-5,
    residual_channels: Optional[int] = None,
    initial_stride: int = 1,
    initial_skip_conv: bool = False,
    kernel_initializer: str = 'VarianceScaling',
    kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
    bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
):
  """Stacks residual blocks one after the other.

  Args:
    reps: `int`, the desired number of residual blocks.
    out_channels: `int`, filter depth of the final residual block.
    use_sync_bn: A `bool`; if True, use synchronized batch normalization.
    norm_momentum: `float`, momentum for the batch normalization layers.
    norm_epsilon: `float`, epsilon for the batch normalization layers.
    residual_channels: `int`, filter depth for the first `reps - 1` residual
      blocks. If None, defaults to the same value as `out_channels`. If not
      equal to `out_channels`, a projection shortcut is used in the final
      residual block.
    initial_stride: `int`, stride for the first residual block.
    initial_skip_conv: `bool`; if set, the first residual block uses a skip
      convolution. This is useful when the number of channels in the input
      is not the same as `residual_channels`.
    kernel_initializer: A `str` for the kernel initializer of the
      convolutional layers.
    kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
      Conv2D. Defaults to None.
    bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2D.
      Defaults to None.

  Returns:
    A `tf.keras.Sequential` of the stacked residual blocks, to be applied in
      sequence.
  """
  blocks = []

  if residual_channels is None:
    residual_channels = out_channels

  for i in range(reps - 1):
    # Only use the stride at the first block so we don't repeatedly
    # downsample the input.
    stride = initial_stride if i == 0 else 1

    # If the stride is more than 1, we cannot use an identity layer for the
    # skip connection and are forced to use a conv for the skip connection.
    skip_conv = stride > 1

    if i == 0 and initial_skip_conv:
      skip_conv = True

    blocks.append(
        nn_blocks.ResidualBlock(
            filters=residual_channels,
            strides=stride,
            use_explicit_padding=True,
            use_projection=skip_conv,
            use_sync_bn=use_sync_bn,
            norm_momentum=norm_momentum,
            norm_epsilon=norm_epsilon,
            kernel_initializer=kernel_initializer,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer))

  if reps == 1:
    # If there is only 1 block, the `for` loop above is not run; therefore we
    # honor the requested stride in the last residual block.
    stride = initial_stride
    # We are forced to use a conv in the skip connection if stride > 1.
    skip_conv = stride > 1
  else:
    stride = 1
    skip_conv = residual_channels != out_channels

  blocks.append(
      nn_blocks.ResidualBlock(
          filters=out_channels,
          strides=stride,
          use_explicit_padding=True,
          use_projection=skip_conv,
          use_sync_bn=use_sync_bn,
          norm_momentum=norm_momentum,
          norm_epsilon=norm_epsilon,
          kernel_initializer=kernel_initializer,
          kernel_regularizer=kernel_regularizer,
          bias_regularizer=bias_regularizer))

  return tf.keras.Sequential(blocks)
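
# Illustrative usage sketch (not part of the original file; the shapes and
# values below are assumptions chosen for demonstration). With an initial
# stride of 2, the first block halves the spatial dimensions, and the final
# block sets the output channel depth:
#
#   stack = _make_repeated_residual_blocks(
#       reps=2, out_channels=64, initial_stride=2, initial_skip_conv=True)
#   y = stack(tf.ones((1, 128, 128, 32)))  # y.shape == (1, 64, 64, 64)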
@tf.keras.utils.register_keras_serializable(package='centernet')
class HourglassBlock(tf.keras.layers.Layer):
  """Hourglass module: an encoder-decoder block."""

  def __init__(self,
               channel_dims_per_stage: List[int],
               blocks_per_stage: List[int],
               strides: int = 1,
               use_sync_bn: bool = True,
               norm_momentum: float = 0.1,
               norm_epsilon: float = 1e-5,
               kernel_initializer: str = 'VarianceScaling',
               kernel_regularizer: Optional[
                   tf.keras.regularizers.Regularizer] = None,
               bias_regularizer: Optional[
                   tf.keras.regularizers.Regularizer] = None,
               **kwargs):
    """Initializes the hourglass module.

    Args:
      channel_dims_per_stage: `List[int]`, list of filter sizes for the
        residual blocks, i.e. the output channel dimensions of the stages in
        the network. `channel_dims[0]` is used to define the number of
        channels in the first encoder block and `channel_dims[1]` is used to
        define the number of channels in the second encoder block. The
        channels in the recursive inner layers are defined using
        `channel_dims[1:]`. For example, [nc * 2, nc * 2, nc * 3, nc * 3,
        nc * 3, nc * 4], where nc is the input channel dimension.
      blocks_per_stage: `List[int]`, list of residual block repetitions per
        down/upsample. `blocks_per_stage[0]` defines the number of blocks at
        the current stage and `blocks_per_stage[1:]` is used at further
        stages. For example, [2, 2, 2, 2, 2, 4].
      strides: `int`, stride parameter for the residual blocks.
      use_sync_bn: A `bool`; if True, use synchronized batch normalization.
      norm_momentum: `float`, momentum for the batch normalization layers.
      norm_epsilon: `float`, epsilon for the batch normalization layers.
      kernel_initializer: A `str` for the kernel initializer of conv layers.
      kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
        Conv2D. Defaults to None.
      bias_regularizer: A `tf.keras.regularizers.Regularizer` object for
        Conv2D. Defaults to None.
      **kwargs: Additional keyword arguments to be passed.
    """
    super(HourglassBlock, self).__init__(**kwargs)

    if len(channel_dims_per_stage) != len(blocks_per_stage):
      raise ValueError('filter size and residual block repetition '
                       'lists must have the same length')

    self._num_stages = len(channel_dims_per_stage) - 1
    self._channel_dims_per_stage = channel_dims_per_stage
    self._blocks_per_stage = blocks_per_stage
    self._strides = strides
    self._use_sync_bn = use_sync_bn
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer

    self._filters = channel_dims_per_stage[0]
    if self._num_stages > 0:
      self._filters_downsampled = channel_dims_per_stage[1]

    self._reps = blocks_per_stage[0]

  def build(self, input_shape):
    if self._num_stages == 0:
      # Base case: residual block repetitions in the innermost part of the
      # hourglass.
      self.blocks = _make_repeated_residual_blocks(
          reps=self._reps,
          out_channels=self._filters,
          use_sync_bn=self._use_sync_bn,
          norm_momentum=self._norm_momentum,
          norm_epsilon=self._norm_epsilon,
          bias_regularizer=self._bias_regularizer,
          kernel_initializer=self._kernel_initializer,
          kernel_regularizer=self._kernel_regularizer)
    else:
      # Outer hourglass structures.
      self.encoder_block1 = _make_repeated_residual_blocks(
          reps=self._reps,
          out_channels=self._filters,
          use_sync_bn=self._use_sync_bn,
          norm_momentum=self._norm_momentum,
          norm_epsilon=self._norm_epsilon,
          bias_regularizer=self._bias_regularizer,
          kernel_initializer=self._kernel_initializer,
          kernel_regularizer=self._kernel_regularizer)

      self.encoder_block2 = _make_repeated_residual_blocks(
          reps=self._reps,
          out_channels=self._filters_downsampled,
          initial_stride=2,
          use_sync_bn=self._use_sync_bn,
          norm_momentum=self._norm_momentum,
          norm_epsilon=self._norm_epsilon,
          bias_regularizer=self._bias_regularizer,
          kernel_initializer=self._kernel_initializer,
          kernel_regularizer=self._kernel_regularizer,
          initial_skip_conv=self._filters != self._filters_downsampled)

      # Recursively define the inner hourglasses.
      self.inner_hg = type(self)(
          channel_dims_per_stage=self._channel_dims_per_stage[1:],
          blocks_per_stage=self._blocks_per_stage[1:],
          strides=self._strides)

      # Outer hourglass structures.
      self.decoder_block = _make_repeated_residual_blocks(
          reps=self._reps,
          residual_channels=self._filters_downsampled,
          out_channels=self._filters,
          use_sync_bn=self._use_sync_bn,
          norm_epsilon=self._norm_epsilon,
          bias_regularizer=self._bias_regularizer,
          kernel_initializer=self._kernel_initializer,
          kernel_regularizer=self._kernel_regularizer)

      self.upsample_layer = tf.keras.layers.UpSampling2D(
          size=2, interpolation='nearest')

    super(HourglassBlock, self).build(input_shape)

  def call(self, x, training=None):
    if self._num_stages == 0:
      return self.blocks(x)
    else:
      encoded_outputs = self.encoder_block1(x)
      encoded_downsampled_outputs = self.encoder_block2(x)
      inner_outputs = self.inner_hg(encoded_downsampled_outputs)
      hg_output = self.decoder_block(inner_outputs)

      return self.upsample_layer(hg_output) + encoded_outputs

  def get_config(self):
    config = {
        'channel_dims_per_stage': self._channel_dims_per_stage,
        'blocks_per_stage': self._blocks_per_stage,
        'strides': self._strides,
        'use_sync_bn': self._use_sync_bn,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
    }
    config.update(super(HourglassBlock, self).get_config())
    return config
@tf.keras.utils.register_keras_serializable(package='centernet')
class CenterNetHeadConv(tf.keras.layers.Layer):
  """Convolution block for the CenterNet head."""

  def __init__(self,
               output_filters: int,
               bias_init: float,
               name: str,
               **kwargs):
    """Initializes the CenterNet head.

    Args:
      output_filters: `int`, channel depth of the layer output.
      bias_init: `float`, value to initialize the bias vector for the final
        convolution layer.
      name: `str`, layer name.
      **kwargs: Additional keyword arguments to be passed.
    """
    super(CenterNetHeadConv, self).__init__(name=name, **kwargs)
    self._output_filters = output_filters
    self._bias_init = bias_init

  def build(self, input_shape):
    n_channels = input_shape[-1]

    self.conv1 = tf.keras.layers.Conv2D(
        filters=n_channels, kernel_size=(3, 3), padding='same')

    self.relu = tf.keras.layers.ReLU()

    # Initialize the bias of the last Conv2D layer.
    self.conv2 = tf.keras.layers.Conv2D(
        filters=self._output_filters,
        kernel_size=(1, 1),
        padding='valid',
        bias_initializer=tf.constant_initializer(self._bias_init))

    super(CenterNetHeadConv, self).build(input_shape)

  def call(self, x, training=None):
    x = self.conv1(x)
    x = self.relu(x)
    x = self.conv2(x)
    return x

  def get_config(self):
    config = {
        'output_filters': self._output_filters,
        'bias_init': self._bias_init,
    }
    config.update(super(CenterNetHeadConv, self).get_config())
    return config
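
As a quick orientation to the blocks above, here is a minimal sketch of wiring them together. The shapes, class count, layer name, and the bias value are illustrative assumptions, not values taken from this commit (the -2.19 bias follows the common CenterNet heatmap initialization so that initial sigmoid outputs are near 0.1):

import tensorflow as tf
from official.vision.beta.projects.centernet.modeling.layers import cn_nn_blocks

# Hourglass encoder-decoder: the output spatial shape matches the input.
hg = cn_nn_blocks.HourglassBlock(
    channel_dims_per_stage=[4, 6, 8, 10, 12],
    blocks_per_stage=[2, 3, 4, 5, 6])
features = hg(tf.zeros((2, 64, 64, 4)))  # -> (2, 64, 64, 4)

# Head conv producing a per-class center heatmap; output_filters=90 and
# bias_init=-2.19 are assumed values for illustration.
head = cn_nn_blocks.CenterNetHeadConv(
    output_filters=90, bias_init=-2.19, name='ct_heatmaps_head')
heatmap = head(features)  # -> (2, 64, 64, 90)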
official/vision/beta/projects/centernet/modeling/layers/cn_nn_blocks_test.py  0 → 100644
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for CenterNet nn_blocks.

It is a literal translation of the PyTorch implementation.
"""

from absl.testing import parameterized
import numpy as np
import tensorflow as tf

from official.vision.beta.modeling.layers import nn_blocks
from official.vision.beta.projects.centernet.modeling.layers import cn_nn_blocks


class HourglassBlockPyTorch(tf.keras.layers.Layer):
  """A CornerNet-style implementation of the hourglass block."""

  def __init__(self, dims, modules, k=0, **kwargs):
    """A CornerNet-style implementation of the hourglass block.

    Args:
      dims: input sizes of the residual blocks.
      modules: number of repetitions of the residual blocks in each hourglass
        upsampling and downsampling.
      k: recursion parameter.
      **kwargs: Additional keyword arguments to be passed.
    """
    super(HourglassBlockPyTorch, self).__init__()

    if len(dims) != len(modules):
      raise ValueError('dims and modules lists must have the same length')

    self.n = len(dims) - 1
    self.k = k
    self.modules = modules
    self.dims = dims
    self._kwargs = kwargs

  def build(self, input_shape):
    modules = self.modules
    dims = self.dims
    k = self.k
    kwargs = self._kwargs

    curr_mod = modules[k]
    next_mod = modules[k + 1]

    curr_dim = dims[k + 0]
    next_dim = dims[k + 1]

    self.up1 = self.make_up_layer(3, curr_dim, curr_dim, curr_mod, **kwargs)
    self.max1 = tf.keras.layers.MaxPool2D(strides=2)
    self.low1 = self.make_hg_layer(3, curr_dim, next_dim, curr_mod, **kwargs)

    if self.n - k > 1:
      self.low2 = type(self)(dims, modules, k=k + 1, **kwargs)
    else:
      self.low2 = self.make_low_layer(
          3, next_dim, next_dim, next_mod, **kwargs)

    self.low3 = self.make_hg_layer_revr(
        3, next_dim, curr_dim, curr_mod, **kwargs)
    self.up2 = tf.keras.layers.UpSampling2D(2)
    self.merge = tf.keras.layers.Add()

    super(HourglassBlockPyTorch, self).build(input_shape)

  def call(self, x):
    up1 = self.up1(x)
    max1 = self.max1(x)
    low1 = self.low1(max1)
    low2 = self.low2(low1)
    low3 = self.low3(low2)
    up2 = self.up2(low3)
    return self.merge([up1, up2])

  def make_layer(self, k, inp_dim, out_dim, modules, **kwargs):
    layers = [nn_blocks.ResidualBlock(out_dim, 1, use_projection=True,
                                      **kwargs)]
    for _ in range(1, modules):
      layers.append(nn_blocks.ResidualBlock(out_dim, 1, **kwargs))
    return tf.keras.Sequential(layers)

  def make_layer_revr(self, k, inp_dim, out_dim, modules, **kwargs):
    layers = []
    for _ in range(modules - 1):
      layers.append(nn_blocks.ResidualBlock(inp_dim, 1, **kwargs))
    layers.append(
        nn_blocks.ResidualBlock(out_dim, 1, use_projection=True, **kwargs))
    return tf.keras.Sequential(layers)

  def make_up_layer(self, k, inp_dim, out_dim, modules, **kwargs):
    return self.make_layer(k, inp_dim, out_dim, modules, **kwargs)

  def make_low_layer(self, k, inp_dim, out_dim, modules, **kwargs):
    return self.make_layer(k, inp_dim, out_dim, modules, **kwargs)

  def make_hg_layer(self, k, inp_dim, out_dim, modules, **kwargs):
    return self.make_layer(k, inp_dim, out_dim, modules, **kwargs)

  def make_hg_layer_revr(self, k, inp_dim, out_dim, modules, **kwargs):
    return self.make_layer_revr(k, inp_dim, out_dim, modules, **kwargs)


class NNBlocksTest(parameterized.TestCase, tf.test.TestCase):

  def test_hourglass_block(self):
    dims = [256, 256, 384, 384, 384, 512]
    modules = [2, 2, 2, 2, 2, 4]
    model = cn_nn_blocks.HourglassBlock(dims, modules)
    test_input = tf.keras.Input((512, 512, 256))
    _ = model(test_input)

    filter_sizes = [256, 256, 384, 384, 384, 512]
    rep_sizes = [2, 2, 2, 2, 2, 4]

    hg_test_input_shape = (1, 512, 512, 256)
    # bb_test_input_shape = (1, 512, 512, 3)
    x_hg = tf.ones(shape=hg_test_input_shape)
    # x_bb = tf.ones(shape=bb_test_input_shape)

    hg = cn_nn_blocks.HourglassBlock(
        channel_dims_per_stage=filter_sizes,
        blocks_per_stage=rep_sizes)

    hg.build(input_shape=hg_test_input_shape)
    out = hg(x_hg)
    self.assertAllEqual(
        tf.shape(out), hg_test_input_shape,
        'Hourglass module output shape and expected shape differ')

    # ODAPI test.
    layer = cn_nn_blocks.HourglassBlock(
        blocks_per_stage=[2, 3, 4, 5, 6],
        channel_dims_per_stage=[4, 6, 8, 10, 12])
    output = layer(np.zeros((2, 64, 64, 4), dtype=np.float32))
    self.assertEqual(output.shape, (2, 64, 64, 4))


if __name__ == '__main__':
  tf.test.main()
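
Because the module ends with tf.test.main(), the test can be run directly; for example (assuming the models repository is on PYTHONPATH):

python -m official.vision.beta.projects.centernet.modeling.layers.cn_nn_blocks_test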
official/vision/beta/projects/centernet/modeling/layers/detection_generator.py  0 → 100644
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Detection generator for CenterNet.

Parses predictions from the CenterNet head into the final bounding boxes,
confidences, and classes. This class contains repurposed methods from the
TensorFlow Object Detection API in:
https://github.com/tensorflow/models/blob/master/research/object_detection/meta_architectures/center_net_meta_arch.py
"""

from typing import Any, Mapping

import tensorflow as tf

from official.vision.beta.ops import box_ops
from official.vision.beta.projects.centernet.ops import loss_ops
from official.vision.beta.projects.centernet.ops import nms_ops


@tf.keras.utils.register_keras_serializable(package='centernet')
class CenterNetDetectionGenerator(tf.keras.layers.Layer):
  """CenterNet detection generator."""

  def __init__(self,
               input_image_dims: int = 512,
               net_down_scale: int = 4,
               max_detections: int = 100,
               peak_error: float = 1e-6,
               peak_extract_kernel_size: int = 3,
               class_offset: int = 1,
               use_nms: bool = False,
               nms_pre_thresh: float = 0.1,
               nms_thresh: float = 0.4,
               **kwargs):
    """Initializes the CenterNet detection generator.

    Args:
      input_image_dims: An `int` that specifies the input image size.
      net_down_scale: An `int` that specifies the stride of the output.
      max_detections: An `int` specifying the maximum number of bounding
        boxes generated. This is an upper bound, so the number of generated
        boxes may be less than this due to thresholding/non-maximum
        suppression.
      peak_error: A `float` for determining non-valid heatmap locations to
        mask.
      peak_extract_kernel_size: An `int` indicating the kernel size used when
        performing max-pool over the heatmaps to detect valid center
        locations from their neighbors. Following the paper, set this to 3 to
        detect valid locations that have responses greater than their
        8-connected neighbors.
      class_offset: An `int` indicating the offset to add to the class
        prediction if the dataset labels have been shifted.
      use_nms: A `bool` for whether or not to use non-maximum suppression to
        filter the bounding boxes.
      nms_pre_thresh: A `float` for the pre-NMS threshold.
      nms_thresh: A `float` for the NMS threshold.
      **kwargs: Additional keyword arguments to be passed.
    """
    super(CenterNetDetectionGenerator, self).__init__(**kwargs)

    # Object center selection parameters.
    self._max_detections = max_detections
    self._peak_error = peak_error
    self._peak_extract_kernel_size = peak_extract_kernel_size

    # Used for adjusting the class prediction.
    self._class_offset = class_offset

    # Box normalization parameters.
    self._net_down_scale = net_down_scale
    self._input_image_dims = input_image_dims

    self._use_nms = use_nms
    self._nms_pre_thresh = nms_pre_thresh
    self._nms_thresh = nms_thresh

  def process_heatmap(self,
                      feature_map: tf.Tensor,
                      kernel_size: int) -> tf.Tensor:
    """Processes the heatmap into peaks for box selection.

    Given a heatmap, this function first masks out nearby heatmap locations
    of the same class using max-pooling such that, ideally, only one center
    for the object remains. Then, center locations are masked according to
    their scores in comparison to a threshold. NOTE: Repurposed from the
    Google OD API.

    Args:
      feature_map: A Tensor with shape [batch_size, height, width,
        num_classes] which is the center heatmap predictions.
      kernel_size: An integer value for the max-pool kernel size.

    Returns:
      A Tensor with the same shape as the input but with non-valid center
      prediction locations masked out.
    """
    feature_map = tf.math.sigmoid(feature_map)
    if not kernel_size or kernel_size == 1:
      feature_map_peaks = feature_map
    else:
      feature_map_max_pool = tf.nn.max_pool(
          feature_map, ksize=kernel_size, strides=1, padding='SAME')

      feature_map_peak_mask = tf.math.abs(
          feature_map - feature_map_max_pool) < self._peak_error

      # Zero out everything that is not a peak.
      feature_map_peaks = (
          feature_map * tf.cast(feature_map_peak_mask, feature_map.dtype))

    return feature_map_peaks
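
  # Illustrative worked example (not from the original file): for a
  # single-channel 1-D slice with sigmoid scores [0.1, 0.9, 0.3, 0.2, 0.8]
  # and kernel_size=3, the 3-wide SAME max-pool yields
  # [0.9, 0.9, 0.9, 0.8, 0.8]; only positions 1 and 4 match their pooled
  # value within _peak_error, so the masked peaks are [0, 0.9, 0, 0, 0.8].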
  def get_top_k_peaks(self,
                      feature_map_peaks: tf.Tensor,
                      batch_size: int,
                      width: int,
                      num_classes: int,
                      k: int = 100):
    """Gets the scores and indices of the top-k peaks from the feature map.

    This function flattens the feature map in order to retrieve the top-k
    peaks, then computes the x, y, and class indices for those scores.
    NOTE: Repurposed from the Google OD API.

    Args:
      feature_map_peaks: A `Tensor` with shape [batch_size, height, width,
        num_classes] which is the processed center heatmap peaks.
      batch_size: An `int` that indicates the batch size of the input.
      width: An `int` that indicates the width (and also height) of the
        input.
      num_classes: An `int` for the number of possible classes. This is also
        the channel depth of the input.
      k: An `int` that controls how many peaks to select.

    Returns:
      top_scores: A Tensor with shape [batch_size, k] containing the top-k
        scores.
      y_indices: A Tensor with shape [batch_size, k] containing the top-k
        y-indices corresponding to top_scores.
      x_indices: A Tensor with shape [batch_size, k] containing the top-k
        x-indices corresponding to top_scores.
      channel_indices: A Tensor with shape [batch_size, k] containing the
        top-k channel indices corresponding to top_scores.
    """
    # Flatten the entire prediction per batch.
    feature_map_peaks_flat = tf.reshape(feature_map_peaks, [batch_size, -1])

    # top_scores and top_indices have shape [batch_size, k].
    top_scores, top_indices = tf.math.top_k(feature_map_peaks_flat, k=k)

    # Get x, y and channel indices corresponding to the top indices in the
    # flat array.
    y_indices, x_indices, channel_indices = (
        loss_ops.get_row_col_channel_indices_from_flattened_indices(
            top_indices, width, num_classes))

    return top_scores, y_indices, x_indices, channel_indices

  def get_boxes(self,
                y_indices: tf.Tensor,
                x_indices: tf.Tensor,
                channel_indices: tf.Tensor,
                height_width_predictions: tf.Tensor,
                offset_predictions: tf.Tensor,
                num_boxes: int):
    """Organizes prediction information into the final bounding boxes.

    NOTE: Repurposed from the Google OD API.

    Args:
      y_indices: A Tensor with shape [batch_size, k] containing the top-k
        y-indices corresponding to top_scores.
      x_indices: A Tensor with shape [batch_size, k] containing the top-k
        x-indices corresponding to top_scores.
      channel_indices: A Tensor with shape [batch_size, k] containing the
        top-k channel indices corresponding to top_scores.
      height_width_predictions: A Tensor with shape [batch_size, height,
        width, 2] containing the object size predictions.
      offset_predictions: A Tensor with shape [batch_size, height, width, 2]
        containing the object local offset predictions.
      num_boxes: An `int`, the number of boxes.

    Returns:
      boxes: A Tensor with shape [batch_size, num_boxes, 4] that contains the
        bounding box coordinates in [y_min, x_min, y_max, x_max] format.
      detection_classes: A Tensor with shape [batch_size, num_boxes] that
        gives the class prediction for each box.
      num_detections: Number of non-zero confidence detections made.
    """
    # TF Lite does not support tf.gather with batch_dims > 0, so we need to
    # use tf.gather_nd instead, and here we prepare the indices for that.

    # Shapes of the heatmap output.
    shape = tf.shape(height_width_predictions)
    batch_size, height, width = shape[0], shape[1], shape[2]

    # Combined indices, dtype=int32.
    combined_indices = tf.stack([
        loss_ops.multi_range(batch_size, value_repetitions=num_boxes),
        tf.reshape(y_indices, [-1]),
        tf.reshape(x_indices, [-1])
    ], axis=1)

    new_height_width = tf.gather_nd(height_width_predictions,
                                    combined_indices)
    new_height_width = tf.reshape(new_height_width,
                                  [batch_size, num_boxes, 2])
    height_width = tf.maximum(new_height_width, 0.0)

    # Heights and widths, dtype=float32.
    heights = height_width[..., 0]
    widths = height_width[..., 1]

    # Get the offsets of the center points.
    new_offsets = tf.gather_nd(offset_predictions, combined_indices)
    offsets = tf.reshape(new_offsets, [batch_size, num_boxes, 2])

    # Offsets, dtype=float32.
    y_offsets = offsets[..., 0]
    x_offsets = offsets[..., 1]

    y_indices = tf.cast(y_indices, dtype=heights.dtype)
    x_indices = tf.cast(x_indices, dtype=widths.dtype)

    detection_classes = channel_indices + self._class_offset

    ymin = y_indices + y_offsets - heights / 2.0
    xmin = x_indices + x_offsets - widths / 2.0
    ymax = y_indices + y_offsets + heights / 2.0
    xmax = x_indices + x_offsets + widths / 2.0

    ymin = tf.clip_by_value(ymin, 0., tf.cast(height, ymin.dtype))
    xmin = tf.clip_by_value(xmin, 0., tf.cast(width, xmin.dtype))
    ymax = tf.clip_by_value(ymax, 0., tf.cast(height, ymax.dtype))
    xmax = tf.clip_by_value(xmax, 0., tf.cast(width, xmax.dtype))

    boxes = tf.stack([ymin, xmin, ymax, xmax], axis=2)
    return boxes, detection_classes

  def convert_strided_predictions_to_normalized_boxes(self,
                                                      boxes: tf.Tensor):
    boxes = boxes * tf.cast(self._net_down_scale, boxes.dtype)
    boxes = boxes / tf.cast(self._input_image_dims, boxes.dtype)
    boxes = tf.clip_by_value(boxes, 0.0, 1.0)
    return boxes

  def __call__(self, inputs):
    # Get heatmaps from decoded outputs via the final hourglass stack output.
    all_ct_heatmaps = inputs['ct_heatmaps']
    all_ct_sizes = inputs['ct_size']
    all_ct_offsets = inputs['ct_offset']

    ct_heatmaps = all_ct_heatmaps[-1]
    ct_sizes = all_ct_sizes[-1]
    ct_offsets = all_ct_offsets[-1]

    shape = tf.shape(ct_heatmaps)

    _, width = shape[1], shape[2]
    batch_size, num_channels = shape[0], shape[3]

    # Process heatmaps using a 3x3 max pool and applying sigmoid.
    peaks = self.process_heatmap(
        feature_map=ct_heatmaps,
        kernel_size=self._peak_extract_kernel_size)

    # Get the top scores along with their x, y, and class indices.
    # Each has size [batch_size, k].
    scores, y_indices, x_indices, channel_indices = self.get_top_k_peaks(
        feature_map_peaks=peaks,
        batch_size=batch_size,
        width=width,
        num_classes=num_channels,
        k=self._max_detections)

    # Parse the scores and indices into bounding boxes.
    boxes, classes = self.get_boxes(
        y_indices=y_indices,
        x_indices=x_indices,
        channel_indices=channel_indices,
        height_width_predictions=ct_sizes,
        offset_predictions=ct_offsets,
        num_boxes=self._max_detections)

    # Normalize the bounding boxes.
    boxes = self.convert_strided_predictions_to_normalized_boxes(boxes)

    # Apply NMS.
    if self._use_nms:
      boxes = tf.expand_dims(boxes, axis=-2)
      multi_class_scores = tf.gather_nd(
          peaks, tf.stack([y_indices, x_indices], -1), batch_dims=1)

      boxes, _, scores = nms_ops.nms(
          boxes=boxes,
          classes=multi_class_scores,
          confidence=scores,
          k=self._max_detections,
          limit_pre_thresh=True,
          pre_nms_thresh=0.1,
          nms_thresh=0.4)

    num_det = tf.reduce_sum(tf.cast(scores > 0, dtype=tf.int32), axis=1)
    boxes = box_ops.denormalize_boxes(
        boxes, [self._input_image_dims, self._input_image_dims])

    return {
        'boxes': boxes,
        'classes': classes,
        'confidence': scores,
        'num_detections': num_det
    }

  def get_config(self) -> Mapping[str, Any]:
    config = {
        'max_detections': self._max_detections,
        'peak_error': self._peak_error,
        'peak_extract_kernel_size': self._peak_extract_kernel_size,
        'class_offset': self._class_offset,
        'net_down_scale': self._net_down_scale,
        'input_image_dims': self._input_image_dims,
        'use_nms': self._use_nms,
        'nms_pre_thresh': self._nms_pre_thresh,
        'nms_thresh': self._nms_thresh
    }

    base_config = super(CenterNetDetectionGenerator, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config):
    return cls(**config)
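
A minimal end-to-end sketch of the generator above; the tensor shapes are assumptions consistent with the defaults (input_image_dims=512, net_down_scale=4, so the head outputs are 128x128), not values taken from this commit:

import tensorflow as tf
from official.vision.beta.projects.centernet.modeling.layers import detection_generator

generator = detection_generator.CenterNetDetectionGenerator()
batch, size, num_classes = 2, 128, 90  # 512 / 4 = 128; class count assumed

outputs = generator({
    # Lists of per-hourglass-stack outputs; only the last entry is used.
    'ct_heatmaps': [tf.random.normal((batch, size, size, num_classes))],
    'ct_size': [tf.random.normal((batch, size, size, 2))],
    'ct_offset': [tf.random.normal((batch, size, size, 2))],
})
print(outputs['boxes'].shape)     # (2, 100, 4), in input-image pixels
print(outputs['num_detections'])  # per-image count of non-zero scores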
official/vision/beta/projects/centernet/ops/__init__.py  0 → 100644
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
official/vision/beta/projects/centernet/ops/box_list.py  0 → 100644
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Bounding Box List definition.

BoxList represents a list of bounding boxes as TensorFlow tensors, where
each bounding box is represented as a row of 4 numbers,
[y_min, x_min, y_max, x_max]. It is assumed that all bounding boxes within a
given list correspond to a single image. See also box_list_ops.py for common
box-related operations (such as area, IoU, etc).

Optionally, users can add additional related fields (such as weights).
We assume the following things to be true about fields:
* they correspond to boxes in the box_list along the 0th dimension
* they have inferrable rank at graph construction time
* all dimensions except for possibly the 0th can be inferred
  (i.e., not None) at graph construction time.

Some other notes:
* Following TensorFlow conventions, we use height, width ordering, and
  correspondingly, y, x (or ymin, xmin, ymax, xmax) ordering.
* Tensors are always provided as (flat) [N, 4] tensors.
"""

import tensorflow as tf


def _get_dim_as_int(dim):
  """Utility to get a v1 or v2 TensorShape dimension as an int.

  Args:
    dim: The TensorShape dimension to get as an int.

  Returns:
    None or an int.
  """
  try:
    return dim.value
  except AttributeError:
    return dim


class BoxList(object):
  """Box collection."""

  def __init__(self, boxes):
    """Constructs the box collection.

    Args:
      boxes: a tensor of shape [N, 4] representing box corners.

    Raises:
      ValueError: if invalid dimensions for bbox data or if bbox data is not
        in float32 format.
    """
    if len(boxes.get_shape()) != 2 or boxes.get_shape()[-1] != 4:
      raise ValueError('Invalid dimensions for box data: {}'.format(
          boxes.shape))
    if boxes.dtype != tf.float32:
      raise ValueError('Invalid tensor type: should be tf.float32')
    self.data = {'boxes': boxes}

  def num_boxes(self):
    """Returns the number of boxes held in the collection.

    Returns:
      a tensor representing the number of boxes held in the collection.
    """
    return tf.shape(self.data['boxes'])[0]

  def num_boxes_static(self):
    """Returns the number of boxes held in the collection.

    This number is inferred at graph construction time rather than run-time.

    Returns:
      Number of boxes held in the collection (integer) or None if this is
      not inferrable at graph construction time.
    """
    return _get_dim_as_int(self.data['boxes'].get_shape()[0])

  def get_all_fields(self):
    """Returns all fields."""
    return self.data.keys()

  def get_extra_fields(self):
    """Returns all non-box fields (i.e., everything not named 'boxes')."""
    return [k for k in self.data.keys() if k != 'boxes']

  def add_field(self, field, field_data):
    """Adds a field to the box list.

    This method can be used to add related box data such as weights/labels,
    etc.

    Args:
      field: a string key to access the data via `get`.
      field_data: a tensor containing the data to store in the BoxList.
    """
    self.data[field] = field_data

  def has_field(self, field):
    return field in self.data

  def get(self):
    """Convenience function for accessing box coordinates.

    Returns:
      a tensor with shape [N, 4] representing box coordinates.
    """
    return self.get_field('boxes')

  def set(self, boxes):
    """Convenience function for setting box coordinates.

    Args:
      boxes: a tensor of shape [N, 4] representing box corners.

    Raises:
      ValueError: if invalid dimensions for bbox data.
    """
    if len(boxes.get_shape()) != 2 or boxes.get_shape()[-1] != 4:
      raise ValueError('Invalid dimensions for box data.')
    self.data['boxes'] = boxes

  def get_field(self, field):
    """Accesses a box collection and associated fields.

    This function returns the specified field of the object; if no field is
    specified, it returns the box coordinates.

    Args:
      field: this optional string parameter can be used to specify a related
        field to be accessed.

    Returns:
      a tensor representing the box collection or an associated field.

    Raises:
      ValueError: if an invalid field is requested.
    """
    if not self.has_field(field):
      raise ValueError('field ' + str(field) + ' does not exist')
    return self.data[field]

  def set_field(self, field, value):
    """Sets the value of a field.

    Updates the field of a box_list with a given value.

    Args:
      field: (string) name of the field to set.
      value: the value to assign to the field.

    Raises:
      ValueError: if the box_list does not have the specified field.
    """
    if not self.has_field(field):
      raise ValueError('field %s does not exist' % field)
    self.data[field] = value

  def get_center_coordinates_and_sizes(self):
    """Computes the center coordinates, height and width of the boxes.

    Returns:
      a list of 4 1-D tensors [ycenter, xcenter, height, width].
    """
    with tf.name_scope('get_center_coordinates_and_sizes'):
      box_corners = self.get()
      ymin, xmin, ymax, xmax = tf.unstack(tf.transpose(box_corners))
      width = xmax - xmin
      height = ymax - ymin
      ycenter = ymin + height / 2.
      xcenter = xmin + width / 2.
      return [ycenter, xcenter, height, width]

  def transpose_coordinates(self):
    """Transposes the coordinate representation in a boxlist."""
    with tf.name_scope('transpose_coordinates'):
      y_min, x_min, y_max, x_max = tf.split(
          value=self.get(), num_or_size_splits=4, axis=1)
      self.set(tf.concat([x_min, y_min, x_max, y_max], 1))

  def as_tensor_dict(self, fields=None):
    """Retrieves specified fields as a dictionary of tensors.

    Args:
      fields: (optional) list of fields to return in the dictionary. If None
        (default), all fields are returned.

    Returns:
      tensor_dict: A dictionary of tensors specified by fields.

    Raises:
      ValueError: if a specified field is not contained in the boxlist.
    """
    tensor_dict = {}
    if fields is None:
      fields = self.get_all_fields()
    for field in fields:
      if not self.has_field(field):
        raise ValueError('boxlist must contain all specified fields')
      tensor_dict[field] = self.get_field(field)
    return tensor_dict
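
A small sketch of the BoxList container in use; the coordinates and the 'scores' field below are made-up illustrative values:

import tensorflow as tf
from official.vision.beta.projects.centernet.ops import box_list

boxes = box_list.BoxList(
    tf.constant([[0., 0., 10., 10.], [2., 2., 6., 8.]], tf.float32))
boxes.add_field('scores', tf.constant([0.9, 0.4]))

print(boxes.num_boxes_static())  # 2, known at graph construction time
ycenter, xcenter, height, width = boxes.get_center_coordinates_and_sizes()
print(boxes.as_tensor_dict(['boxes', 'scores']))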
official/vision/beta/projects/centernet/ops/box_list_ops.py  0 → 100644
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Bounding Box List operations."""

import tensorflow as tf

from official.vision.beta.ops import sampling_ops
from official.vision.beta.projects.centernet.ops import box_list


def _copy_extra_fields(boxlist_to_copy_to, boxlist_to_copy_from):
  """Copies the extra fields of boxlist_to_copy_from to boxlist_to_copy_to.

  Args:
    boxlist_to_copy_to: BoxList to which extra fields are copied.
    boxlist_to_copy_from: BoxList from which fields are copied.

  Returns:
    boxlist_to_copy_to with extra fields.
  """
  for field in boxlist_to_copy_from.get_extra_fields():
    boxlist_to_copy_to.add_field(field,
                                 boxlist_to_copy_from.get_field(field))
  return boxlist_to_copy_to


def scale(boxlist, y_scale, x_scale):
  """Scales box coordinates in x and y dimensions.

  Args:
    boxlist: BoxList holding N boxes.
    y_scale: (float) scalar tensor.
    x_scale: (float) scalar tensor.

  Returns:
    boxlist: BoxList holding N boxes.
  """
  with tf.name_scope('Scale'):
    y_scale = tf.cast(y_scale, tf.float32)
    x_scale = tf.cast(x_scale, tf.float32)
    y_min, x_min, y_max, x_max = tf.split(
        value=boxlist.get(), num_or_size_splits=4, axis=1)
    y_min = y_scale * y_min
    y_max = y_scale * y_max
    x_min = x_scale * x_min
    x_max = x_scale * x_max
    scaled_boxlist = box_list.BoxList(
        tf.concat([y_min, x_min, y_max, x_max], 1))
    return _copy_extra_fields(scaled_boxlist, boxlist)


def area(boxlist):
  """Computes the area of boxes.

  Args:
    boxlist: BoxList holding N boxes.

  Returns:
    a tensor with shape [N] representing box areas.
  """
  with tf.name_scope('Area'):
    y_min, x_min, y_max, x_max = tf.split(
        value=boxlist.get(), num_or_size_splits=4, axis=1)
    return tf.squeeze((y_max - y_min) * (x_max - x_min), [1])


def change_coordinate_frame(boxlist, window):
  """Changes the coordinate frame of the boxlist to be relative to the window.

  Given a window of the form [ymin, xmin, ymax, xmax], changes bounding box
  coordinates from boxlist to be relative to this window (e.g., the min
  corner maps to (0, 0) and the max corner maps to (1, 1)).

  An example use case is data augmentation: we are given groundtruth boxes
  (boxlist) and would like to randomly crop the image to some window
  (window). In this case we need to change the coordinate frame of each
  groundtruth box to be relative to this new window.

  Args:
    boxlist: A BoxList object holding N boxes.
    window: A rank-1 tensor of shape [4].

  Returns:
    Returns a BoxList object with N boxes.
  """
  with tf.name_scope('ChangeCoordinateFrame'):
    win_height = window[2] - window[0]
    win_width = window[3] - window[1]
    boxlist_new = scale(
        box_list.BoxList(boxlist.get() -
                         [window[0], window[1], window[0], window[1]]),
        1.0 / win_height, 1.0 / win_width)
    boxlist_new = _copy_extra_fields(boxlist_new, boxlist)
    return boxlist_new


def matmul_gather_on_zeroth_axis(params, indices):
  """Matrix multiplication based implementation of tf.gather on zeroth axis.

  Args:
    params: A float32 Tensor. The tensor from which to gather values. Must
      be at least rank 1.
    indices: A Tensor. Must be one of the following types: int32, int64.
      Must be in range [0, params.shape[0]).

  Returns:
    A Tensor. Has the same type as params. Values from params gathered from
    indices given by indices, with shape indices.shape + params.shape[1:].
  """
  with tf.name_scope('MatMulGather'):
    params_shape = sampling_ops.combined_static_and_dynamic_shape(params)
    indices_shape = sampling_ops.combined_static_and_dynamic_shape(indices)
    params2d = tf.reshape(params, [params_shape[0], -1])
    indicator_matrix = tf.one_hot(indices, params_shape[0])
    gathered_result_flattened = tf.matmul(indicator_matrix, params2d)
    return tf.reshape(gathered_result_flattened,
                      tf.stack(indices_shape + params_shape[1:]))


def gather(boxlist, indices, fields=None, use_static_shapes=False):
  """Gathers boxes from BoxList according to indices; returns a new BoxList.

  By default, `gather` returns boxes corresponding to the input index list,
  as well as all additional fields stored in the boxlist (indexing into the
  first dimension). However one can optionally only gather from a subset of
  fields.

  Args:
    boxlist: BoxList holding N boxes.
    indices: a rank-1 tensor of type int32 / int64.
    fields: (optional) list of fields to also gather from. If None
      (default), all fields are gathered from. Pass an empty fields list to
      only gather the box coordinates.
    use_static_shapes: Whether to use an implementation with static shape
      guarantees.

  Returns:
    subboxlist: a BoxList corresponding to the subset of the input BoxList
      specified by indices.

  Raises:
    ValueError: if a specified field is not contained in the boxlist or if
      the indices are not of type int32.
  """
  with tf.name_scope('Gather'):
    if len(indices.shape.as_list()) != 1:
      raise ValueError('indices should have rank 1')
    if indices.dtype != tf.int32 and indices.dtype != tf.int64:
      raise ValueError('indices should be an int32 / int64 tensor')
    gather_op = tf.gather
    if use_static_shapes:
      gather_op = matmul_gather_on_zeroth_axis
    subboxlist = box_list.BoxList(gather_op(boxlist.get(), indices))
    if fields is None:
      fields = boxlist.get_extra_fields()
    fields += ['boxes']
    for field in fields:
      if not boxlist.has_field(field):
        raise ValueError('boxlist must contain all specified fields')
      subfieldlist = gather_op(boxlist.get_field(field), indices)
      subboxlist.add_field(field, subfieldlist)
    return subboxlist


def prune_completely_outside_window(boxlist, window):
  """Prunes bounding boxes that fall completely outside of the given window.

  The function clip_to_window prunes bounding boxes that fall completely
  outside the window, but also clips any bounding boxes that partially
  overflow. This function does not clip partially overflowing boxes.

  Args:
    boxlist: a BoxList holding M_in boxes.
    window: a float tensor of shape [4] representing the
      [ymin, xmin, ymax, xmax] of the window.

  Returns:
    pruned_boxlist: a new BoxList with all bounding boxes partially or fully
      in the window.
    valid_indices: a tensor with shape [M_out] indexing the valid bounding
      boxes in the input tensor.
  """
  with tf.name_scope('PruneCompletelyOutsideWindow'):
    y_min, x_min, y_max, x_max = tf.split(
        value=boxlist.get(), num_or_size_splits=4, axis=1)
    win_y_min, win_x_min, win_y_max, win_x_max = tf.unstack(window)
    coordinate_violations = tf.concat([
        tf.greater_equal(y_min, win_y_max),
        tf.greater_equal(x_min, win_x_max),
        tf.less_equal(y_max, win_y_min),
        tf.less_equal(x_max, win_x_min)
    ], 1)
    valid_indices = tf.reshape(
        tf.where(tf.logical_not(tf.reduce_any(coordinate_violations, 1))),
        [-1])
    return gather(boxlist, valid_indices), valid_indices


def clip_to_window(boxlist, window, filter_nonoverlapping=True):
  """Clips bounding boxes to a window.

  This op clips any input bounding boxes (represented by bounding box
  corners) to a window, optionally filtering out boxes that do not overlap
  at all with the window.

  Args:
    boxlist: BoxList holding M_in boxes.
    window: a tensor of shape [4] representing the
      [y_min, x_min, y_max, x_max] window to which the op should clip boxes.
    filter_nonoverlapping: whether to filter out boxes that do not overlap
      at all with the window.

  Returns:
    a BoxList holding M_out boxes where M_out <= M_in.
  """
  with tf.name_scope('ClipToWindow'):
    y_min, x_min, y_max, x_max = tf.split(
        value=boxlist.get(), num_or_size_splits=4, axis=1)
    win_y_min = window[0]
    win_x_min = window[1]
    win_y_max = window[2]
    win_x_max = window[3]
    y_min_clipped = tf.maximum(tf.minimum(y_min, win_y_max), win_y_min)
    y_max_clipped = tf.maximum(tf.minimum(y_max, win_y_max), win_y_min)
    x_min_clipped = tf.maximum(tf.minimum(x_min, win_x_max), win_x_min)
    x_max_clipped = tf.maximum(tf.minimum(x_max, win_x_max), win_x_min)
    clipped = box_list.BoxList(
        tf.concat([y_min_clipped, x_min_clipped, y_max_clipped,
                   x_max_clipped], 1))
    clipped = _copy_extra_fields(clipped, boxlist)
    if filter_nonoverlapping:
      areas = area(clipped)
      nonzero_area_indices = tf.cast(
          tf.reshape(tf.where(tf.greater(areas, 0.0)), [-1]), tf.int32)
      clipped = gather(clipped, nonzero_area_indices)
    return clipped


def height_width(boxlist):
  """Computes the height and width of boxes in a boxlist.

  Args:
    boxlist: BoxList holding N boxes.

  Returns:
    Height: A tensor with shape [N] representing box heights.
    Width: A tensor with shape [N] representing box widths.
  """
  with tf.name_scope('HeightWidth'):
    y_min, x_min, y_max, x_max = tf.split(
        value=boxlist.get(), num_or_size_splits=4, axis=1)
    return tf.squeeze(y_max - y_min, [1]), tf.squeeze(x_max - x_min, [1])


def prune_small_boxes(boxlist, min_side):
  """Prunes boxes in the boxlist that have a side smaller than min_side.

  Args:
    boxlist: BoxList holding N boxes.
    min_side: minimum width AND height for a box to survive pruning.

  Returns:
    A pruned boxlist.
  """
  with tf.name_scope('PruneSmallBoxes'):
    height, width = height_width(boxlist)
    is_valid = tf.logical_and(tf.greater_equal(width, min_side),
                              tf.greater_equal(height, min_side))
    return gather(boxlist, tf.reshape(tf.where(is_valid), [-1]))


def assert_or_prune_invalid_boxes(boxes):
  """Makes sure boxes have valid sizes (ymax >= ymin, xmax >= xmin).

  When the hardware supports assertions, the function raises an error when
  boxes have an invalid size. If assertions are not supported (e.g. on TPU),
  boxes with invalid sizes are filtered out.

  Args:
    boxes: float tensor of shape [num_boxes, 4].

  Returns:
    boxes: float tensor of shape [num_valid_boxes, 4] with invalid boxes
      filtered out.

  Raises:
    tf.errors.InvalidArgumentError: When we detect boxes with an invalid
      size. This is not supported on TPUs.
  """
  ymin, xmin, ymax, xmax = tf.split(boxes, num_or_size_splits=4, axis=1)

  height_check = tf.Assert(tf.reduce_all(ymax >= ymin), [ymin, ymax])
  width_check = tf.Assert(tf.reduce_all(xmax >= xmin), [xmin, xmax])

  with tf.control_dependencies([height_check, width_check]):
    boxes_tensor = tf.concat([ymin, xmin, ymax, xmax], axis=1)
    boxlist = box_list.BoxList(boxes_tensor)
    boxlist = prune_small_boxes(boxlist, 0)

  return boxlist.get()


def to_absolute_coordinates(boxlist,
                            height,
                            width,
                            check_range=True,
                            maximum_normalized_coordinate=1.1):
  """Converts normalized box coordinates to absolute pixel coordinates.

  This function raises an assertion-failed error when the maximum box
  coordinate value is larger than maximum_normalized_coordinate (in which
  case coordinates are already absolute).

  Args:
    boxlist: BoxList with coordinates in range [0, 1].
    height: Maximum value for the height of absolute box coordinates.
    width: Maximum value for the width of absolute box coordinates.
    check_range: If True, checks whether the coordinates are normalized.
    maximum_normalized_coordinate: Maximum coordinate value to be considered
      as normalized, defaults to 1.1.

  Returns:
    boxlist with absolute coordinates in terms of the image size.
  """
  with tf.name_scope('ToAbsoluteCoordinates'):
    height = tf.cast(height, tf.float32)
    width = tf.cast(width, tf.float32)

    # Ensure the range of the input boxes is correct.
    if check_range:
      box_maximum = tf.reduce_max(boxlist.get())
      max_assert = tf.Assert(
          tf.greater_equal(maximum_normalized_coordinate, box_maximum),
          ['maximum box coordinate value is larger than %f: ' %
           maximum_normalized_coordinate, box_maximum])
      with tf.control_dependencies([max_assert]):
        width = tf.identity(width)

    return scale(boxlist, height, width)
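
A short sketch combining the ops above; the window and the boxes are illustrative values:

import tensorflow as tf
from official.vision.beta.projects.centernet.ops import box_list, box_list_ops

bl = box_list.BoxList(
    tf.constant([[0.0, 0.0, 4.0, 4.0], [6.0, 6.0, 9.0, 9.0]], tf.float32))
window = tf.constant([0.0, 0.0, 5.0, 5.0])

# The second box lies entirely outside the window, so it is pruned.
pruned, valid = box_list_ops.prune_completely_outside_window(bl, window)
# clip_to_window clips the second box to zero area and then filters it.
clipped = box_list_ops.clip_to_window(bl, window)
areas = box_list_ops.area(clipped)  # [16.0]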
official/vision/beta/projects/centernet/ops/loss_ops.py  0 → 100644
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Operations for compute losses for centernet."""
import
tensorflow
as
tf
from
official.vision.beta.ops
import
sampling_ops
def
_get_shape
(
tensor
,
num_dims
):
assert
len
(
tensor
.
shape
.
as_list
())
==
num_dims
return
sampling_ops
.
combined_static_and_dynamic_shape
(
tensor
)
def
flatten_spatial_dimensions
(
batch_images
):
# pylint: disable=unbalanced-tuple-unpacking
batch_size
,
height
,
width
,
channels
=
_get_shape
(
batch_images
,
4
)
return
tf
.
reshape
(
batch_images
,
[
batch_size
,
height
*
width
,
channels
])
def
multi_range
(
limit
,
value_repetitions
=
1
,
range_repetitions
=
1
,
dtype
=
tf
.
int32
):
"""Creates a sequence with optional value duplication and range repetition.
As an example (see the Args section for more details),
_multi_range(limit=2, value_repetitions=3, range_repetitions=4) returns:
[0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1]
NOTE: Repurposed from Google OD API.
Args:
limit: A 0-D Tensor (scalar). Upper limit of sequence, exclusive.
value_repetitions: Integer. The number of times a value in the sequence is
repeated. With value_repetitions=3, the result is [0, 0, 0, 1, 1, 1, ..].
range_repetitions: Integer. The number of times the range is repeated. With
range_repetitions=3, the result is [0, 1, 2, .., 0, 1, 2, ..].
dtype: The type of the elements of the resulting tensor.
Returns:
A 1-D tensor of type `dtype` and size
[`limit` * `value_repetitions` * `range_repetitions`] that contains the
specified range with given repetitions.
"""
return
tf
.
reshape
(
tf
.
tile
(
tf
.
expand_dims
(
tf
.
range
(
limit
,
dtype
=
dtype
),
axis
=-
1
),
multiples
=
[
range_repetitions
,
value_repetitions
]),
[
-
1
])
def
add_batch_to_indices
(
indices
):
shape
=
tf
.
shape
(
indices
)
batch_size
=
shape
[
0
]
num_instances
=
shape
[
1
]
batch_range
=
multi_range
(
limit
=
batch_size
,
value_repetitions
=
num_instances
)
batch_range
=
tf
.
reshape
(
batch_range
,
shape
=
(
batch_size
,
num_instances
,
1
))
return
tf
.
concat
([
batch_range
,
indices
],
axis
=-
1
)
def
get_num_instances_from_weights
(
gt_weights_list
):
"""Computes the number of instances/boxes from the weights in a batch.
Args:
gt_weights_list: A list of float tensors with shape
[max_num_instances] representing whether there is an actual instance in
the image (with non-zero value) or is padded to match the
max_num_instances (with value 0.0). The list represents the batch
dimension.
Returns:
A scalar integer tensor incidating how many instances/boxes are in the
images in the batch. Note that this function is usually used to normalize
the loss so the minimum return value is 1 to avoid weird behavior.
"""
# This can execute in graph mode
gt_weights_list
=
tf
.
convert_to_tensor
(
gt_weights_list
,
dtype
=
gt_weights_list
[
0
].
dtype
)
num_instances
=
tf
.
map_fn
(
fn
=
lambda
x
:
tf
.
math
.
count_nonzero
(
x
,
dtype
=
gt_weights_list
[
0
].
dtype
),
elems
=
gt_weights_list
)
num_instances
=
tf
.
reduce_sum
(
num_instances
)
num_instances
=
tf
.
maximum
(
num_instances
,
1
)
return
num_instances
def
get_batch_predictions_from_indices
(
batch_predictions
,
indices
):
"""Gets the values of predictions in a batch at the given indices.
The indices are expected to come from the offset targets generation functions
in this library. The returned value is intended to be used inside a loss
function.
Args:
batch_predictions: A tensor of shape [batch_size, height, width, channels]
or [batch_size, height, width, class, channels] for class-specific
features (e.g. keypoint joint offsets).
indices: A tensor of shape [num_instances, 3] for single class features or
[num_instances, 4] for multiple classes features.
Returns:
values: A tensor of shape [num_instances, channels] holding the predicted
values at the given indices.
"""
return
tf
.
gather_nd
(
batch_predictions
,
indices
)
def get_valid_anchor_weights_in_flattened_image(true_image_shapes, height,
                                                width):
  """Computes valid anchor weights for an image assuming pixels to be flattened.

  This function is useful when we only want to penalize valid areas in the
  image in the case when padding is used. The function assumes that the loss
  function will be applied after flattening the spatial dimensions and returns
  anchor weights accordingly.

  Args:
    true_image_shapes: An integer tensor of shape [batch_size, 3] representing
      the true image shape (without padding) for each sample in the batch.
    height: height of the prediction from the network.
    width: width of the prediction from the network.

  Returns:
    valid_anchor_weights: a float tensor of shape [batch_size, height * width]
      with 1s in locations where the spatial coordinates fall within the height
      and width in true_image_shapes.
  """
  indices = tf.reshape(tf.range(height * width), [1, -1])
  batch_size = tf.shape(true_image_shapes)[0]
  batch_indices = tf.ones((batch_size, 1), dtype=tf.int32) * indices

  y_coords, x_coords, _ = get_row_col_channel_indices_from_flattened_indices(
      batch_indices, width, 1)

  max_y, max_x = true_image_shapes[:, 0], true_image_shapes[:, 1]
  max_x = tf.cast(tf.expand_dims(max_x, 1), tf.float32)
  max_y = tf.cast(tf.expand_dims(max_y, 1), tf.float32)

  x_coords = tf.cast(x_coords, tf.float32)
  y_coords = tf.cast(y_coords, tf.float32)

  valid_mask = tf.math.logical_and(x_coords < max_x, y_coords < max_y)

  return tf.cast(valid_mask, tf.float32)
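# Usage sketch (hypothetical values): for a 2x4 prediction whose true image
# occupies only 2 rows and 3 columns, the flattened weights mask out the
# padded fourth column of every row:
#   get_valid_anchor_weights_in_flattened_image(
#       tf.constant([[2, 3, 3]]), height=2, width=4)
#   -> [[1., 1., 1., 0., 1., 1., 1., 0.]]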
def get_row_col_channel_indices_from_flattened_indices(indices: int,
                                                       num_cols: int,
                                                       num_channels: int):
  """Computes row, column and channel indices from flattened indices.

  NOTE: Repurposed from Google OD API.

  Args:
    indices: An `int` tensor of any shape holding the indices in the flattened
      space.
    num_cols: `int`, number of columns in the image (width).
    num_channels: `int`, number of channels in the image.

  Returns:
    row_indices: The row indices corresponding to each of the input indices.
      Same shape as indices.
    col_indices: The column indices corresponding to each of the input indices.
      Same shape as indices.
    channel_indices: The channel indices corresponding to each of the input
      indices.
  """
  # Avoid using the mod operator to keep the ops compatible with more
  # environments, e.g. WASM.
  # All inputs and outputs are dtype int32.
  row_indices = (indices // num_channels) // num_cols
  col_indices = (indices // num_channels) - row_indices * num_cols
  channel_indices_temp = indices // num_channels
  channel_indices = indices - channel_indices_temp * num_channels

  return row_indices, col_indices, channel_indices
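# Worked example (illustrative): with num_cols=4 and num_channels=3, the
# flattened index 23 decomposes as 23 // 3 = 7, channel = 23 - 7 * 3 = 2,
# row = 7 // 4 = 1, col = 7 - 1 * 4 = 3, i.e. (row, col, channel) = (1, 3, 2).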
official/vision/beta/projects/centernet/ops/nms_ops.py
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""nms computation."""
import
tensorflow
as
tf
from
official.vision.beta.projects.yolo.ops
import
box_ops
NMS_TILE_SIZE
=
512
# pylint: disable=missing-function-docstring
def aggregated_comparative_iou(boxes1, boxes2=None, iou_type=0):
  k = tf.shape(boxes1)[-2]

  boxes1 = tf.expand_dims(boxes1, axis=-2)
  boxes1 = tf.tile(boxes1, [1, 1, k, 1])

  if boxes2 is not None:
    boxes2 = tf.expand_dims(boxes2, axis=-2)
    boxes2 = tf.tile(boxes2, [1, 1, k, 1])
    boxes2 = tf.transpose(boxes2, perm=(0, 2, 1, 3))
  else:
    boxes2 = tf.transpose(boxes1, perm=(0, 2, 1, 3))

  if iou_type == 0:  # diou
    _, iou = box_ops.compute_diou(boxes1, boxes2)
  elif iou_type == 1:  # giou
    _, iou = box_ops.compute_giou(boxes1, boxes2)
  else:
    iou = box_ops.compute_iou(boxes1, boxes2, yxyx=True)
  return iou
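# Shape sketch (inferred from the code above): given boxes1 of shape
# [batch, k, 4], every box is broadcast against every other box, yielding a
# dense [batch, k, k] pairwise IoU matrix that segment_nms below uses to find
# duplicates in one shot.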
# pylint: disable=missing-function-docstring
def sort_drop(objectness, box, classificationsi, k):
  objectness, ind = tf.math.top_k(objectness, k=k)

  ind_m = tf.ones_like(ind) * tf.expand_dims(
      tf.range(0, tf.shape(objectness)[0]), axis=-1)
  bind = tf.stack([tf.reshape(ind_m, [-1]), tf.reshape(ind, [-1])], axis=-1)

  box = tf.gather_nd(box, bind)
  classifications = tf.gather_nd(classificationsi, bind)

  bsize = tf.shape(ind)[0]
  box = tf.reshape(box, [bsize, k, -1])
  classifications = tf.reshape(classifications, [bsize, k, -1])
  return objectness, box, classifications
# pylint: disable=missing-function-docstring
def segment_nms(boxes, classes, confidence, k, iou_thresh):
  mrange = tf.range(k)
  mask_x = tf.tile(
      tf.transpose(tf.expand_dims(mrange, axis=-1), perm=[1, 0]), [k, 1])
  mask_y = tf.tile(tf.expand_dims(mrange, axis=-1), [1, k])
  mask_diag = tf.expand_dims(mask_x > mask_y, axis=0)

  iou = aggregated_comparative_iou(boxes, iou_type=0)

  # duplicate boxes
  iou_mask = iou >= iou_thresh
  iou_mask = tf.logical_and(mask_diag, iou_mask)
  iou *= tf.cast(iou_mask, iou.dtype)

  can_suppress_others = 1 - tf.cast(
      tf.reduce_any(iou_mask, axis=-2), boxes.dtype)

  raw = tf.cast(can_suppress_others, boxes.dtype)
  boxes *= tf.expand_dims(raw, axis=-1)
  confidence *= tf.cast(raw, confidence.dtype)
  classes *= tf.cast(raw, classes.dtype)

  return boxes, classes, confidence
# pylint: disable=missing-function-docstring
def nms(boxes,
        classes,
        confidence,
        k,
        pre_nms_thresh,
        nms_thresh,
        limit_pre_thresh=False,
        use_classes=True):
  if limit_pre_thresh:
    confidence, boxes, classes = sort_drop(confidence, boxes, classes, k)

  mask = tf.fill(
      tf.shape(confidence), tf.cast(pre_nms_thresh, dtype=confidence.dtype))
  mask = tf.math.ceil(tf.nn.relu(confidence - mask))
  confidence = confidence * mask
  mask = tf.expand_dims(mask, axis=-1)
  boxes = boxes * mask
  classes = classes * mask

  if use_classes:
    confidence = tf.reduce_max(classes, axis=-1)
  confidence, boxes, classes = sort_drop(confidence, boxes, classes, k)

  classes = tf.cast(tf.argmax(classes, axis=-1), tf.float32)
  boxes, classes, confidence = segment_nms(boxes, classes, confidence, k,
                                           nms_thresh)
  confidence, boxes, classes = sort_drop(confidence, boxes, classes, k)
  classes = tf.squeeze(classes, axis=-1)
  return boxes, classes, confidence
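# Usage sketch (shapes inferred from the code above; values hypothetical):
#   boxes:      [batch, n, 4] float, yxyx format
#   classes:    [batch, n, num_classes] float per-class scores
#   confidence: [batch, n] float objectness
#   boxes, class_ids, scores = nms(
#       boxes, classes, confidence, k=200,
#       pre_nms_thresh=0.1, nms_thresh=0.5)
# class_ids and scores come back with shape [batch, k]; suppressed entries are
# zeroed and pushed to the end by the final sort_drop.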
official/vision/beta/projects/centernet/ops/preprocess_ops.py
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Preprocessing ops imported from OD API."""
import
functools
import
tensorflow
as
tf
from
official.vision.beta.projects.centernet.ops
import
box_list
from
official.vision.beta.projects.centernet.ops
import
box_list_ops
def _get_or_create_preprocess_rand_vars(generator_func,
                                        function_id,
                                        preprocess_vars_cache,
                                        key=''):
  """Returns a tensor stored in preprocess_vars_cache or using generator_func.

  If the tensor was previously generated and appears in the PreprocessorCache,
  the previously generated tensor will be returned. Otherwise, a new tensor
  is generated using generator_func and stored in the cache.

  Args:
    generator_func: A 0-argument function that generates a tensor.
    function_id: identifier for the preprocessing function used.
    preprocess_vars_cache: PreprocessorCache object that records previously
      performed augmentations. Updated in-place. If this function is called
      multiple times with the same non-null cache, it will perform
      deterministically.
    key: identifier for the variable stored.

  Returns:
    The generated tensor.
  """
  if preprocess_vars_cache is not None:
    var = preprocess_vars_cache.get(function_id, key)
    if var is None:
      var = generator_func()
      preprocess_vars_cache.update(function_id, key, var)
  else:
    var = generator_func()
  return var
def _random_integer(minval, maxval, seed):
  """Returns a random 0-D tensor between minval and maxval.

  Args:
    minval: minimum value of the random tensor.
    maxval: maximum value of the random tensor.
    seed: random seed.

  Returns:
    A random 0-D tensor between minval and maxval.
  """
  return tf.random.uniform(
      [], minval=minval, maxval=maxval, dtype=tf.int32, seed=seed)
def _get_crop_border(border, size):
  """Get the border of cropping."""
  border = tf.cast(border, tf.float32)
  size = tf.cast(size, tf.float32)

  i = tf.math.ceil(tf.math.log(2.0 * border / size) / tf.math.log(2.0))
  divisor = tf.pow(2.0, i)
  divisor = tf.clip_by_value(divisor, 1, border)
  divisor = tf.cast(divisor, tf.int32)

  return tf.cast(border, tf.int32) // divisor
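# Worked example (illustrative): with border=128 and size=512,
# i = ceil(log2(2 * 128 / 512)) = -1, so the divisor clips to 1 and the full
# 128-pixel border is kept; for size=192, i = ceil(log2(256 / 192)) = 1 and
# the border shrinks to 128 // 2 = 64, keeping it under half the image
# dimension as the docstring of random_square_crop_by_scale describes.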
def random_square_crop_by_scale(image,
                                boxes,
                                labels,
                                max_border=128,
                                scale_min=0.6,
                                scale_max=1.3,
                                num_scales=8,
                                seed=None,
                                preprocess_vars_cache=None):
  """Randomly crop a square in proportion to scale and image size.

  Extract a square sized crop from an image whose side length is sampled by
  randomly scaling the maximum spatial dimension of the image. If part of
  the crop falls outside the image, it is filled with zeros.
  The augmentation is borrowed from [1]
  [1]: https://arxiv.org/abs/1904.07850

  Args:
    image: rank 3 float32 tensor containing 1 image ->
      [height, width, channels].
    boxes: rank 2 float32 tensor containing the bounding boxes -> [N, 4].
      Boxes are in normalized form meaning their coordinates vary
      between [0, 1]. Each row is in the form of [ymin, xmin, ymax, xmax].
      Boxes on the crop boundary are clipped to the boundary and boxes
      falling outside the crop are ignored.
    labels: rank 1 int32 tensor containing the object classes.
    max_border: The maximum size of the border. The border defines distance in
      pixels to the image boundaries that will not be considered as a center of
      a crop. To make sure that the border does not go over the center of the
      image, we chose the border value by computing the minimum k, such that
      (max_border / (2**k)) < image_dimension/2.
    scale_min: float, the minimum value for scale.
    scale_max: float, the maximum value for scale.
    num_scales: int, the number of discrete scale values to sample between
      [scale_min, scale_max]
    seed: random seed.
    preprocess_vars_cache: PreprocessorCache object that records previously
      performed augmentations. Updated in-place. If this function is called
      multiple times with the same non-null cache, it will perform
      deterministically.

  Returns:
    image: image which is the same rank as input image.
    boxes: boxes which is the same rank as input boxes.
      Boxes are in normalized form.
    labels: new labels.
  """
  img_shape = tf.shape(image)
  height, width = img_shape[0], img_shape[1]
  scales = tf.linspace(scale_min, scale_max, num_scales)

  scale = _get_or_create_preprocess_rand_vars(
      lambda: scales[_random_integer(0, num_scales, seed)],
      'square_crop_scale',
      preprocess_vars_cache, 'scale')

  image_size = scale * tf.cast(tf.maximum(height, width), tf.float32)
  image_size = tf.cast(image_size, tf.int32)
  h_border = _get_crop_border(max_border, height)
  w_border = _get_crop_border(max_border, width)

  def y_function():
    y = _random_integer(h_border,
                        tf.cast(height, tf.int32) - h_border + 1,
                        seed)
    return y

  def x_function():
    x = _random_integer(w_border,
                        tf.cast(width, tf.int32) - w_border + 1,
                        seed)
    return x

  y_center = _get_or_create_preprocess_rand_vars(
      y_function,
      'square_crop_scale',
      preprocess_vars_cache, 'y_center')

  x_center = _get_or_create_preprocess_rand_vars(
      x_function,
      'square_crop_scale',
      preprocess_vars_cache, 'x_center')

  half_size = tf.cast(image_size / 2, tf.int32)
  crop_ymin, crop_ymax = y_center - half_size, y_center + half_size
  crop_xmin, crop_xmax = x_center - half_size, x_center + half_size

  ymin = tf.maximum(crop_ymin, 0)
  xmin = tf.maximum(crop_xmin, 0)
  ymax = tf.minimum(crop_ymax, height - 1)
  xmax = tf.minimum(crop_xmax, width - 1)

  cropped_image = image[ymin:ymax, xmin:xmax]
  offset_y = tf.maximum(0, ymin - crop_ymin)
  offset_x = tf.maximum(0, xmin - crop_xmin)

  oy_i = offset_y
  ox_i = offset_x

  output_image = tf.image.pad_to_bounding_box(
      cropped_image, offset_height=oy_i, offset_width=ox_i,
      target_height=image_size, target_width=image_size)

  if ymin == 0:
    # We might be padding the image.
    box_ymin = -offset_y
  else:
    box_ymin = crop_ymin

  if xmin == 0:
    # We might be padding the image.
    box_xmin = -offset_x
  else:
    box_xmin = crop_xmin

  box_ymax = box_ymin + image_size
  box_xmax = box_xmin + image_size

  image_box = [box_ymin / height, box_xmin / width,
               box_ymax / height, box_xmax / width]
  boxlist = box_list.BoxList(boxes)
  boxlist = box_list_ops.change_coordinate_frame(boxlist, image_box)
  boxlist, indices = box_list_ops.prune_completely_outside_window(
      boxlist, [0.0, 0.0, 1.0, 1.0])
  boxlist = box_list_ops.clip_to_window(
      boxlist, [0.0, 0.0, 1.0, 1.0],
      filter_nonoverlapping=False)

  return_values = [output_image, boxlist.get(), tf.gather(labels, indices)]

  return return_values
def resize_to_range(image,
                    masks=None,
                    min_dimension=None,
                    max_dimension=None,
                    method=tf.image.ResizeMethod.BILINEAR,
                    pad_to_max_dimension=False,
                    per_channel_pad_value=(0, 0, 0)):
  """Resizes an image so its dimensions are within the provided value.

  The output size can be described by two cases:
  1. If the image can be rescaled so its minimum dimension is equal to the
     provided value without the other dimension exceeding max_dimension,
     then do so.
  2. Otherwise, resize so the largest dimension is equal to max_dimension.

  Args:
    image: A 3D tensor of shape [height, width, channels]
    masks: (optional) rank 3 float32 tensor with shape
      [num_instances, height, width] containing instance masks.
    min_dimension: (optional) (scalar) desired size of the smaller image
      dimension.
    max_dimension: (optional) (scalar) maximum allowed size
      of the larger image dimension.
    method: (optional) interpolation method used in resizing. Defaults to
      BILINEAR.
    pad_to_max_dimension: Whether to resize the image and pad it with zeros
      so the resulting image is of the spatial size
      [max_dimension, max_dimension]. If masks are included they are padded
      similarly.
    per_channel_pad_value: A tuple of per-channel scalar value to use for
      padding. By default pads zeros.

  Returns:
    Note that the position of the resized_image_shape changes based on whether
    masks are present.
    resized_image: A 3D tensor of shape [new_height, new_width, channels],
      where the image has been resized (with bilinear interpolation) so that
      min(new_height, new_width) == min_dimension or
      max(new_height, new_width) == max_dimension.
    resized_masks: If masks is not None, also outputs masks. A 3D tensor of
      shape [num_instances, new_height, new_width].
    resized_image_shape: A 1D tensor of shape [3] containing shape of the
      resized image.

  Raises:
    ValueError: if the image is not a 3D tensor.
  """
  if len(image.get_shape()) != 3:
    raise ValueError('Image should be 3D tensor')

  def _resize_landscape_image(image):
    # resize a landscape image
    return tf.image.resize(
        image, tf.stack([min_dimension, max_dimension]), method=method,
        preserve_aspect_ratio=True)

  def _resize_portrait_image(image):
    # resize a portrait image
    return tf.image.resize(
        image, tf.stack([max_dimension, min_dimension]), method=method,
        preserve_aspect_ratio=True)

  with tf.name_scope('ResizeToRange'):
    if image.get_shape().is_fully_defined():
      if image.get_shape()[0] < image.get_shape()[1]:
        new_image = _resize_landscape_image(image)
      else:
        new_image = _resize_portrait_image(image)
      new_size = tf.constant(new_image.get_shape().as_list())
    else:
      new_image = tf.cond(
          tf.less(tf.shape(image)[0], tf.shape(image)[1]),
          lambda: _resize_landscape_image(image),
          lambda: _resize_portrait_image(image))
      new_size = tf.shape(new_image)

    if pad_to_max_dimension:
      channels = tf.unstack(new_image, axis=2)
      if len(channels) != len(per_channel_pad_value):
        raise ValueError('Number of channels must be equal to the length of '
                         'per-channel pad value.')
      new_image = tf.stack(
          [
              tf.pad(  # pylint: disable=g-complex-comprehension
                  channels[i], [[0, max_dimension - new_size[0]],
                                [0, max_dimension - new_size[1]]],
                  constant_values=per_channel_pad_value[i])
              for i in range(len(channels))
          ],
          axis=2)
      new_image.set_shape([max_dimension, max_dimension, len(channels)])

    result = [new_image, new_size]
    if masks is not None:
      new_masks = tf.expand_dims(masks, 3)
      new_masks = tf.image.resize(
          new_masks,
          new_size[:-1],
          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
      if pad_to_max_dimension:
        new_masks = tf.image.pad_to_bounding_box(
            new_masks, 0, 0, max_dimension, max_dimension)
      new_masks = tf.squeeze(new_masks, 3)
      result.append(new_masks)

    return result
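# Worked example (illustrative): with min_dimension=100 and max_dimension=200,
# a 50x100 image scales by 2 to 100x200 (case 1), while a 50x300 image would
# overshoot max_dimension, so it scales by 200/300 to roughly 33x200 (case 2).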
def _augment_only_rgb_channels(image, augment_function):
  """Augments only the RGB slice of an image with additional channels."""
  rgb_slice = image[:, :, :3]
  augmented_rgb_slice = augment_function(rgb_slice)
  image = tf.concat([augmented_rgb_slice, image[:, :, 3:]], -1)
  return image
def random_adjust_brightness(image,
                             max_delta=0.2,
                             seed=None,
                             preprocess_vars_cache=None):
  """Randomly adjusts brightness.

  Makes sure the output image is still between 0 and 255.

  Args:
    image: rank 3 float32 tensor contains 1 image -> [height, width, channels]
      with pixel values varying between [0, 255].
    max_delta: how much to change the brightness. A value between [0, 1).
    seed: random seed.
    preprocess_vars_cache: PreprocessorCache object that records previously
      performed augmentations. Updated in-place. If this function is called
      multiple times with the same non-null cache, it will perform
      deterministically.

  Returns:
    image: image which is the same shape as input image.
  """
  with tf.name_scope('RandomAdjustBrightness'):
    generator_func = functools.partial(tf.random.uniform, [],
                                       -max_delta, max_delta, seed=seed)
    delta = _get_or_create_preprocess_rand_vars(
        generator_func,
        'adjust_brightness',
        preprocess_vars_cache)

    def _adjust_brightness(image):
      image = tf.image.adjust_brightness(image / 255, delta) * 255
      image = tf.clip_by_value(image, clip_value_min=0.0, clip_value_max=255.0)
      return image

    image = _augment_only_rgb_channels(image, _adjust_brightness)
    return image
def random_adjust_contrast(image,
                           min_delta=0.8,
                           max_delta=1.25,
                           seed=None,
                           preprocess_vars_cache=None):
  """Randomly adjusts contrast.

  Makes sure the output image is still between 0 and 255.

  Args:
    image: rank 3 float32 tensor contains 1 image -> [height, width, channels]
      with pixel values varying between [0, 255].
    min_delta: see max_delta.
    max_delta: how much to change the contrast. Contrast will change with a
      value between min_delta and max_delta. This value will be multiplied to
      the current contrast of the image.
    seed: random seed.
    preprocess_vars_cache: PreprocessorCache object that records previously
      performed augmentations. Updated in-place. If this function is called
      multiple times with the same non-null cache, it will perform
      deterministically.

  Returns:
    image: image which is the same shape as input image.
  """
  with tf.name_scope('RandomAdjustContrast'):
    generator_func = functools.partial(tf.random.uniform, [],
                                       min_delta, max_delta, seed=seed)
    contrast_factor = _get_or_create_preprocess_rand_vars(
        generator_func,
        'adjust_contrast',
        preprocess_vars_cache)

    def _adjust_contrast(image):
      image = tf.image.adjust_contrast(image / 255, contrast_factor) * 255
      image = tf.clip_by_value(image, clip_value_min=0.0, clip_value_max=255.0)
      return image

    image = _augment_only_rgb_channels(image, _adjust_contrast)
    return image
def random_adjust_hue(image,
                      max_delta=0.02,
                      seed=None,
                      preprocess_vars_cache=None):
  """Randomly adjusts hue.

  Makes sure the output image is still between 0 and 255.

  Args:
    image: rank 3 float32 tensor contains 1 image -> [height, width, channels]
      with pixel values varying between [0, 255].
    max_delta: change hue randomly with a value between 0 and max_delta.
    seed: random seed.
    preprocess_vars_cache: PreprocessorCache object that records previously
      performed augmentations. Updated in-place. If this function is called
      multiple times with the same non-null cache, it will perform
      deterministically.

  Returns:
    image: image which is the same shape as input image.
  """
  with tf.name_scope('RandomAdjustHue'):
    generator_func = functools.partial(tf.random.uniform, [],
                                       -max_delta, max_delta, seed=seed)
    delta = _get_or_create_preprocess_rand_vars(
        generator_func,
        'adjust_hue',
        preprocess_vars_cache)

    def _adjust_hue(image):
      image = tf.image.adjust_hue(image / 255, delta) * 255
      image = tf.clip_by_value(image, clip_value_min=0.0, clip_value_max=255.0)
      return image

    image = _augment_only_rgb_channels(image, _adjust_hue)
    return image
def random_adjust_saturation(image,
                             min_delta=0.8,
                             max_delta=1.25,
                             seed=None,
                             preprocess_vars_cache=None):
  """Randomly adjusts saturation.

  Makes sure the output image is still between 0 and 255.

  Args:
    image: rank 3 float32 tensor contains 1 image -> [height, width, channels]
      with pixel values varying between [0, 255].
    min_delta: see max_delta.
    max_delta: how much to change the saturation. Saturation will change with a
      value between min_delta and max_delta. This value will be multiplied to
      the current saturation of the image.
    seed: random seed.
    preprocess_vars_cache: PreprocessorCache object that records previously
      performed augmentations. Updated in-place. If this function is called
      multiple times with the same non-null cache, it will perform
      deterministically.

  Returns:
    image: image which is the same shape as input image.
  """
  with tf.name_scope('RandomAdjustSaturation'):
    generator_func = functools.partial(tf.random.uniform, [],
                                       min_delta, max_delta, seed=seed)
    saturation_factor = _get_or_create_preprocess_rand_vars(
        generator_func,
        'adjust_saturation',
        preprocess_vars_cache)

    def _adjust_saturation(image):
      image = tf.image.adjust_saturation(image / 255, saturation_factor) * 255
      image = tf.clip_by_value(image, clip_value_min=0.0, clip_value_max=255.0)
      return image

    image = _augment_only_rgb_channels(image, _adjust_saturation)
    return image
official/vision/beta/projects/centernet/ops/target_assigner.py
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generate targets (center, scale, offsets,...) for centernet."""
from
typing
import
Dict
,
List
import
tensorflow
as
tf
from
official.vision.beta.ops
import
sampling_ops
def smallest_positive_root(a, b, c):
  """Returns the smallest positive root of a quadratic equation."""
  discriminant = tf.sqrt(b ** 2 - 4 * a * c)

  # Note: this mirrors the original CenterNet `gaussian_radius` computation,
  # which takes (-b + sqrt(discriminant)) and divides by 2 rather than 2a.
  return (-b + discriminant) / (2.0)
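# Worked check (illustrative): for a=1, b=-5, c=6 the roots are 2 and 3;
# sqrt(b^2 - 4ac) = 1 and (-b + 1) / 2 = 3, so this form keeps the
# (-b + sqrt)/2 convention of the CenterNet radius computation rather than
# the textbook smallest root (-b - sqrt(b^2 - 4ac)) / (2a) = 2.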
@tf.function
def cartesian_product(*tensors, repeat: int = 1) -> tf.Tensor:
  """Equivalent of itertools.product except for TensorFlow tensors.

  Example:
    cartesian_product(tf.range(3), tf.range(4)) returns:
    array([[0, 0],
           [0, 1],
           [0, 2],
           [0, 3],
           [1, 0],
           [1, 1],
           [1, 2],
           [1, 3],
           [2, 0],
           [2, 1],
           [2, 2],
           [2, 3]], dtype=int32)

  Args:
    *tensors: a list of 1D tensors to compute the product of
    repeat: an `int` number of times to repeat the tensors

  Returns:
    An nD tensor where n is the number of tensors
  """
  tensors = tensors * repeat
  return tf.reshape(
      tf.transpose(
          tf.stack(tf.meshgrid(*tensors, indexing='ij')),
          [*[i + 1 for i in range(len(tensors))], 0]),
      (-1, len(tensors)))
def image_shape_to_grids(height: int, width: int):
  """Computes xy-grids given the shape of the image.

  Args:
    height: The height of the image.
    width: The width of the image.

  Returns:
    A tuple of two tensors:
      y_grid: A float tensor with shape [height, width] representing the
        y-coordinate of each pixel grid.
      x_grid: A float tensor with shape [height, width] representing the
        x-coordinate of each pixel grid.
  """
  out_height = tf.cast(height, tf.float32)
  out_width = tf.cast(width, tf.float32)
  x_range = tf.range(out_width, dtype=tf.float32)
  y_range = tf.range(out_height, dtype=tf.float32)
  x_grid, y_grid = tf.meshgrid(x_range, y_range, indexing='xy')
  return (y_grid, x_grid)
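# Usage sketch (illustrative): image_shape_to_grids(2, 3) returns
#   y_grid = [[0., 0., 0.],     x_grid = [[0., 1., 2.],
#             [1., 1., 1.]]               [0., 1., 2.]]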
def max_distance_for_overlap(height, width, min_iou):
  """Computes how far apart bbox corners can lie while maintaining the iou.

  Given a bounding box size, this function returns a lower bound on how far
  apart the corners of another box can lie while still maintaining the given
  IoU. The implementation is based on the `gaussian_radius` function in the
  Objects as Points github repo: https://github.com/xingyizhou/CenterNet

  Args:
    height: A 1-D float Tensor representing height of the ground truth boxes.
    width: A 1-D float Tensor representing width of the ground truth boxes.
    min_iou: A float representing the minimum IoU desired.

  Returns:
    distance: A 1-D Tensor of distances, of the same length as the input
      height and width tensors.
  """
  # Given that the detected box is displaced at a distance `d`, the exact
  # IoU value will depend on the angle at which each corner is displaced.
  # We simplify our computation by assuming that each corner is displaced by
  # a distance `d` in both x and y direction. This gives us a lower IoU than
  # what is actually realizable and ensures that any box with corners less
  # than `d` distance apart will always have an IoU greater than or equal
  # to `min_iou`.

  # The following 3 cases can be worked out geometrically and come down to
  # solving a quadratic inequality. In each case, to ensure `min_iou` we use
  # the smallest positive root of the equation.

  # Case where detected box is offset from ground truth and no box completely
  # contains the other.
  distance_detection_offset = smallest_positive_root(
      a=1, b=-(height + width),
      c=width * height * ((1 - min_iou) / (1 + min_iou)))

  # Case where detection is smaller than ground truth and completely contained
  # in it.
  distance_detection_in_gt = smallest_positive_root(
      a=4, b=-2 * (height + width),
      c=(1 - min_iou) * width * height)

  # Case where ground truth is smaller than detection and completely contained
  # in it.
  distance_gt_in_detection = smallest_positive_root(
      a=4 * min_iou, b=(2 * min_iou) * (width + height),
      c=(min_iou - 1) * width * height)

  return tf.reduce_min([distance_detection_offset,
                        distance_gt_in_detection,
                        distance_detection_in_gt], axis=0)
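# Worked example (illustrative, using the root convention noted above): for a
# 10x10 box with min_iou=0.7, the three cases evaluate to roughly 19.1, 36.7
# and 2.7, so the returned radius is about 2.7 pixels.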
def compute_std_dev_from_box_size(boxes_height, boxes_width, min_overlap):
  """Computes the standard deviation of the Gaussian kernel from box size.

  Args:
    boxes_height: A 1D tensor with shape [num_instances] representing the
      height of each box.
    boxes_width: A 1D tensor with shape [num_instances] representing the width
      of each box.
    min_overlap: The minimum IOU overlap that boxes need to have to not be
      penalized.

  Returns:
    A 1D tensor with shape [num_instances] representing the computed Gaussian
    sigma for each of the box.
  """
  # We are dividing by 3 so that points closer than the computed
  # distance have a >99% CDF.
  sigma = max_distance_for_overlap(boxes_height, boxes_width, min_overlap)
  sigma = (2 * tf.math.maximum(tf.math.floor(sigma), 0.0) + 1) / 6.0
  return sigma
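# Worked example (illustrative): a radius of 2.7 floors to 2, giving
# sigma = (2 * 2 + 1) / 6 = 5 / 6, i.e. roughly a third of the allowed
# displacement, so points within the radius stay inside about 3 standard
# deviations of the Gaussian peak.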
@tf.function
def assign_center_targets(out_height: int,
                          out_width: int,
                          y_center: tf.Tensor,
                          x_center: tf.Tensor,
                          boxes_height: tf.Tensor,
                          boxes_width: tf.Tensor,
                          channel_onehot: tf.Tensor,
                          gaussian_iou: float):
  """Computes the object center heatmap target based on ODAPI implementation.

  Args:
    out_height: int, height of output to the model. This is used to
      determine the height of the output.
    out_width: int, width of the output to the model. This is used to
      determine the width of the output.
    y_center: A 1D tensor with shape [num_instances] representing the
      y-coordinates of the instances in the output space coordinates.
    x_center: A 1D tensor with shape [num_instances] representing the
      x-coordinates of the instances in the output space coordinates.
    boxes_height: A 1D tensor with shape [num_instances] representing the
      height of each box.
    boxes_width: A 1D tensor with shape [num_instances] representing the width
      of each box.
    channel_onehot: A 2D tensor with shape [num_instances, num_channels]
      representing the one-hot encoded channel labels for each point.
    gaussian_iou: The minimum IOU overlap that boxes need to have to not be
      penalized.

  Returns:
    heatmap: A Tensor of size [output_height, output_width,
      num_classes] representing the per class center heatmap. output_height
      and output_width are computed by dividing the input height and width by
      the stride specified during initialization.
  """
  (y_grid, x_grid) = image_shape_to_grids(out_height, out_width)

  sigma = compute_std_dev_from_box_size(boxes_height, boxes_width,
                                        gaussian_iou)

  num_instances, num_channels = (
      sampling_ops.combined_static_and_dynamic_shape(channel_onehot))

  x_grid = tf.expand_dims(x_grid, 2)
  y_grid = tf.expand_dims(y_grid, 2)
  # The raw center coordinates in the output space.
  x_diff = x_grid - tf.math.floor(x_center)
  y_diff = y_grid - tf.math.floor(y_center)
  squared_distance = x_diff ** 2 + y_diff ** 2

  gaussian_map = tf.exp(-squared_distance / (2 * sigma * sigma))

  reshaped_gaussian_map = tf.expand_dims(gaussian_map, axis=-1)
  reshaped_channel_onehot = tf.reshape(channel_onehot,
                                       (1, 1, num_instances, num_channels))
  gaussian_per_box_per_class_map = (
      reshaped_gaussian_map * reshaped_channel_onehot)

  # Take maximum along the "instance" dimension so that all per-instance
  # heatmaps of the same class are merged together.
  heatmap = tf.reduce_max(gaussian_per_box_per_class_map, axis=2)

  # Maximum of an empty tensor is -inf, the following is to avoid that.
  heatmap = tf.maximum(heatmap, 0)

  return tf.stop_gradient(heatmap)
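# Usage sketch (hypothetical values): a single 4x4 box centered at
# (y, x) = (2, 3) on an 8x8 output yields a heatmap with an exact 1.0 peak at
# [2, 3, class_id] that decays as a Gaussian around it:
#   heatmap = assign_center_targets(
#       out_height=8, out_width=8,
#       y_center=tf.constant([2.0]), x_center=tf.constant([3.0]),
#       boxes_height=tf.constant([4.0]), boxes_width=tf.constant([4.0]),
#       channel_onehot=tf.one_hot([1], depth=90), gaussian_iou=0.7)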
def assign_centernet_targets(labels: Dict[str, tf.Tensor],
                             output_size: List[int],
                             input_size: List[int],
                             num_classes: int = 90,
                             max_num_instances: int = 128,
                             gaussian_iou: float = 0.7,
                             class_offset: int = 0,
                             dtype='float32'):
  """Generates the ground truth labels for centernet.

  Ground truth labels are generated by splatting gaussians on heatmaps for
  corners and centers. Regressed features (offsets and sizes) are also
  generated.

  Args:
    labels: A dictionary of COCO ground truth labels with at minimum the
      following fields:
      "bbox" A `Tensor` of shape [max_num_instances, 4], where the
        last dimension corresponds to the top left x, top left y,
        bottom right x, and bottom left y coordinates of the bounding box
      "classes" A `Tensor` of shape [max_num_instances] that contains
        the class of each box, given in the same order as the boxes
      "num_detections" A `Tensor` or int that gives the number of objects
    output_size: A `list` of length 2 containing the desired output height
      and width of the heatmaps
    input_size: A `list` of length 2 the expected input height and width of
      the image
    num_classes: A `Tensor` or `int` for the number of classes.
    max_num_instances: An `int` for maximum number of instances in an image.
    gaussian_iou: A `float` number for the minimum desired IOU used when
      determining the gaussian radius of center locations in the heatmap.
    class_offset: An `int` for subtracting a value from the ground truth
      classes.
    dtype: `str`, data type. One of {`bfloat16`, `float32`, `float16`}.

  Returns:
    Dictionary of labels with the following fields:
      'ct_heatmaps': Tensor of shape [output_h, output_w, num_classes],
        heatmap with splatted gaussians centered at the positions and channels
        corresponding to the center location and class of the object
      'ct_offset': `Tensor` of shape [max_num_instances, 2], where the first
        num_boxes entries contain the x-offset and y-offset of the center of
        an object. All other entries are 0
      'size': `Tensor` of shape [max_num_instances, 2], where the first
        num_boxes entries contain the width and height of an object. All
        other entries are 0
      'box_mask': `Tensor` of shape [max_num_instances], where the first
        num_boxes entries are 1. All other entries are 0
      'box_indices': `Tensor` of shape [max_num_instances, 2], where the first
        num_boxes entries contain the y-center and x-center of a valid box.
        These are used to extract the regressed box features from the
        prediction when computing the loss

  Raises:
    Exception: if datatype is not supported.
  """
  if dtype == 'float16':
    dtype = tf.float16
  elif dtype == 'bfloat16':
    dtype = tf.bfloat16
  elif dtype == 'float32':
    dtype = tf.float32
  else:
    raise Exception(
        'Unsupported datatype used in ground truth builder only '
        '{float16, bfloat16, or float32}')

  # Get relevant bounding box and class information from labels
  # only keep the first num_objects boxes and classes
  num_objects = labels['groundtruths']['num_detections']
  # shape of labels['boxes'] is [max_num_instances, 4]
  # [ymin, xmin, ymax, xmax]
  boxes = tf.cast(labels['boxes'], dtype)
  # shape of labels['classes'] is [max_num_instances, ]
  classes = tf.cast(labels['classes'] - class_offset, dtype)

  # Compute scaling factors for center/corner positions on heatmap
  # input_size = tf.cast(input_size, dtype)
  # output_size = tf.cast(output_size, dtype)
  input_h, input_w = input_size[0], input_size[1]
  output_h, output_w = output_size[0], output_size[1]

  width_ratio = output_w / input_w
  height_ratio = output_h / input_h

  # Original box coordinates
  # [max_num_instances, ]
  ytl, ybr = boxes[..., 0], boxes[..., 2]
  xtl, xbr = boxes[..., 1], boxes[..., 3]
  yct = (ytl + ybr) / 2
  xct = (xtl + xbr) / 2

  # Scaled box coordinates (could be floating point)
  # [max_num_instances, ]
  scale_xct = xct * width_ratio
  scale_yct = yct * height_ratio

  # Floor the scaled box coordinates to be placed on heatmaps
  # [max_num_instances, ]
  scale_xct_floor = tf.math.floor(scale_xct)
  scale_yct_floor = tf.math.floor(scale_yct)

  # Offset computations to make up for discretization error
  # used for offset maps
  # [max_num_instances, 2]
  ct_offset_values = tf.stack([scale_yct - scale_yct_floor,
                               scale_xct - scale_xct_floor], axis=-1)

  # Get the scaled box dimensions for computing the gaussian radius
  # [max_num_instances, ]
  box_widths = boxes[..., 3] - boxes[..., 1]
  box_heights = boxes[..., 2] - boxes[..., 0]

  box_widths = box_widths * width_ratio
  box_heights = box_heights * height_ratio

  # Used for size map
  # [max_num_instances, 2]
  box_heights_widths = tf.stack([box_heights, box_widths], axis=-1)

  # Center/corner heatmaps
  # [output_h, output_w, num_classes]
  ct_heatmap = tf.zeros((output_h, output_w, num_classes), dtype)

  # Maps for offset and size features for each instance of a box
  # [max_num_instances, 2]
  ct_offset = tf.zeros((max_num_instances, 2), dtype)
  # [max_num_instances, 2]
  size = tf.zeros((max_num_instances, 2), dtype)

  # Mask for valid box instances and their center indices in the heatmap
  # [max_num_instances, ]
  box_mask = tf.zeros((max_num_instances,), tf.int32)
  # [max_num_instances, 2]
  box_indices = tf.zeros((max_num_instances, 2), tf.int32)

  if num_objects > 0:
    # Splat gaussians around the centers and corners of the objects
    ct_heatmap = assign_center_targets(
        out_height=output_h,
        out_width=output_w,
        y_center=scale_yct_floor[:num_objects],
        x_center=scale_xct_floor[:num_objects],
        boxes_height=box_heights[:num_objects],
        boxes_width=box_widths[:num_objects],
        channel_onehot=tf.one_hot(tf.cast(classes[:num_objects], tf.int32),
                                  num_classes, off_value=0.),
        gaussian_iou=gaussian_iou)

    # Indices used to update offsets and sizes for valid box instances
    update_indices = cartesian_product(
        tf.range(max_num_instances), tf.range(2))
    # [max_num_instances, 2, 2]
    update_indices = tf.reshape(update_indices,
                                shape=[max_num_instances, 2, 2])

    # Write the offsets of each box instance
    ct_offset = tf.tensor_scatter_nd_update(
        ct_offset, update_indices, ct_offset_values)

    # Write the size of each bounding box
    size = tf.tensor_scatter_nd_update(
        size, update_indices, box_heights_widths)

    # Initially the mask is zeros, so now we unmask each valid box instance
    box_mask = tf.where(tf.range(max_num_instances) < num_objects, 1, 0)

    # Write the y and x coordinate of each box center in the heatmap
    box_index_values = tf.cast(
        tf.stack([scale_yct_floor, scale_xct_floor], axis=-1),
        dtype=tf.int32)
    box_indices = tf.tensor_scatter_nd_update(
        box_indices, update_indices, box_index_values)

  ct_labels = {
      # [output_h, output_w, num_classes]
      'ct_heatmaps': ct_heatmap,
      # [max_num_instances, 2]
      'ct_offset': ct_offset,
      # [max_num_instances, 2]
      'size': size,
      # [max_num_instances, ]
      'box_mask': box_mask,
      # [max_num_instances, 2]
      'box_indices': box_indices
  }
  return ct_labels
official/vision/beta/projects/centernet/ops/target_assigner_test.py
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for targets generations of centernet."""
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
official.vision.beta.ops
import
preprocess_ops
from
official.vision.beta.projects.centernet.ops
import
target_assigner
class TargetAssignerTest(tf.test.TestCase, parameterized.TestCase):

  def check_labels_correct(self,
                           boxes,
                           classes,
                           output_size,
                           input_size):
    max_num_instances = 128
    num_detections = len(boxes)
    boxes = tf.constant(boxes, dtype=tf.float32)
    classes = tf.constant(classes, dtype=tf.float32)

    boxes = preprocess_ops.clip_or_pad_to_fixed_size(
        boxes, max_num_instances, 0)
    classes = preprocess_ops.clip_or_pad_to_fixed_size(
        classes, max_num_instances, 0)

    # pylint: disable=g-long-lambda
    labels = target_assigner.assign_centernet_targets(
        labels={
            'boxes': boxes,
            'classes': classes,
            'groundtruths': {
                'num_detections': num_detections,
            }
        },
        output_size=output_size,
        input_size=input_size)

    ct_heatmaps = labels['ct_heatmaps']
    ct_offset = labels['ct_offset']
    size = labels['size']
    box_mask = labels['box_mask']
    box_indices = labels['box_indices']

    boxes = tf.cast(boxes, tf.float32)
    classes = tf.cast(classes, tf.float32)
    height_ratio = output_size[0] / input_size[0]
    width_ratio = output_size[1] / input_size[1]

    # Shape checks
    self.assertEqual(ct_heatmaps.shape, (output_size[0], output_size[1], 90))

    self.assertEqual(ct_offset.shape, (max_num_instances, 2))

    self.assertEqual(size.shape, (max_num_instances, 2))
    self.assertEqual(box_mask.shape, (max_num_instances,))
    self.assertEqual(box_indices.shape, (max_num_instances, 2))

    self.assertAllInRange(ct_heatmaps, 0, 1)

    for i in range(len(boxes)):
      # Check sizes
      self.assertAllEqual(size[i],
                          [(boxes[i][2] - boxes[i][0]) * height_ratio,
                           (boxes[i][3] - boxes[i][1]) * width_ratio,
                           ])

      # Check box indices
      y = tf.math.floor((boxes[i][0] + boxes[i][2]) / 2 * height_ratio)
      x = tf.math.floor((boxes[i][1] + boxes[i][3]) / 2 * width_ratio)
      self.assertAllEqual(box_indices[i], [y, x])

      # check offsets
      true_y = (boxes[i][0] + boxes[i][2]) / 2 * height_ratio
      true_x = (boxes[i][1] + boxes[i][3]) / 2 * width_ratio
      self.assertAllEqual(ct_offset[i], [true_y - y, true_x - x])

    for i in range(len(boxes), max_num_instances):
      # Make sure rest are zero
      self.assertAllEqual(size[i], [0, 0])
      self.assertAllEqual(box_indices[i], [0, 0])
      self.assertAllEqual(ct_offset[i], [0, 0])

    # Check mask indices
    self.assertAllEqual(tf.cast(box_mask[3:], tf.int32),
                        tf.repeat(0, repeats=max_num_instances - 3))
    self.assertAllEqual(tf.cast(box_mask[:3], tf.int32),
                        tf.repeat(1, repeats=3))

  def test_generate_targets_no_scale(self):
    boxes = [
        (10, 300, 15, 370),
        (100, 300, 150, 370),
        (15, 100, 200, 170),
    ]
    classes = (1, 2, 3)
    sizes = [512, 512]

    self.check_labels_correct(boxes=boxes,
                              classes=classes,
                              output_size=sizes,
                              input_size=sizes)

  def test_generate_targets_stride_4(self):
    boxes = [
        (10, 300, 15, 370),
        (100, 300, 150, 370),
        (15, 100, 200, 170),
    ]
    classes = (1, 2, 3)
    output_size = [128, 128]
    input_size = [512, 512]

    self.check_labels_correct(boxes=boxes,
                              classes=classes,
                              output_size=output_size,
                              input_size=input_size)

  def test_generate_targets_stride_8(self):
    boxes = [
        (10, 300, 15, 370),
        (100, 300, 150, 370),
        (15, 100, 200, 170),
    ]
    classes = (1, 2, 3)
    output_size = [128, 128]
    input_size = [1024, 1024]

    self.check_labels_correct(boxes=boxes,
                              classes=classes,
                              output_size=output_size,
                              input_size=input_size)

  def test_batch_generate_targets(self):
    input_size = [512, 512]
    output_size = [128, 128]
    max_num_instances = 128

    boxes = tf.constant([
        (10, 300, 15, 370),  # center (y, x) = (12, 335)
        (100, 300, 150, 370),  # center (y, x) = (125, 335)
        (15, 100, 200, 170),  # center (y, x) = (107, 135)
    ], dtype=tf.float32)

    classes = tf.constant((1, 1, 1), dtype=tf.float32)

    boxes = preprocess_ops.clip_or_pad_to_fixed_size(
        boxes, max_num_instances, 0)
    classes = preprocess_ops.clip_or_pad_to_fixed_size(
        classes, max_num_instances, 0)

    boxes = tf.stack([boxes, boxes], axis=0)
    classes = tf.stack([classes, classes], axis=0)

    # pylint: disable=g-long-lambda
    labels = tf.map_fn(
        fn=lambda x: target_assigner.assign_centernet_targets(
            labels=x,
            output_size=output_size,
            input_size=input_size),
        elems={
            'boxes': boxes,
            'classes': classes,
            'groundtruths': {
                'num_detections': tf.constant([3, 3]),
            }
        },
        dtype={
            'ct_heatmaps': tf.float32,
            'ct_offset': tf.float32,
            'size': tf.float32,
            'box_mask': tf.int32,
            'box_indices': tf.int32
        }
    )

    ct_heatmaps = labels['ct_heatmaps']
    ct_offset = labels['ct_offset']
    size = labels['size']
    box_mask = labels['box_mask']
    box_indices = labels['box_indices']

    self.assertEqual(ct_heatmaps.shape,
                     (2, output_size[0], output_size[1], 90))

    self.assertEqual(ct_offset.shape, (2, max_num_instances, 2))

    self.assertEqual(size.shape, (2, max_num_instances, 2))
    self.assertEqual(box_mask.shape, (2, max_num_instances))
    self.assertEqual(box_indices.shape, (2, max_num_instances, 2))


if __name__ == '__main__':
  tf.test.main()
official/vision/beta/projects/centernet/tasks/centernet.py
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Centernet task definition."""
from
typing
import
Any
,
List
,
Optional
,
Tuple
from
absl
import
logging
import
tensorflow
as
tf
from
official.core
import
base_task
from
official.core
import
input_reader
from
official.core
import
task_factory
from
official.vision.beta.dataloaders
import
tf_example_decoder
from
official.vision.beta.dataloaders
import
tfds_factory
from
official.vision.beta.dataloaders
import
tf_example_label_map_decoder
from
official.vision.beta.evaluation
import
coco_evaluator
from
official.vision.beta.modeling.backbones
import
factory
from
official.vision.beta.projects.centernet.configs
import
centernet
as
exp_cfg
from
official.vision.beta.projects.centernet.dataloaders
import
centernet_input
from
official.vision.beta.projects.centernet.losses
import
centernet_losses
from
official.vision.beta.projects.centernet.modeling
import
centernet_model
from
official.vision.beta.projects.centernet.modeling.heads
import
centernet_head
from
official.vision.beta.projects.centernet.modeling.layers
import
detection_generator
from
official.vision.beta.projects.centernet.ops
import
loss_ops
from
official.vision.beta.projects.centernet.ops
import
target_assigner
@
task_factory
.
register_task_cls
(
exp_cfg
.
CenterNetTask
)
class
CenterNetTask
(
base_task
.
Task
):
"""Task definition for centernet."""
def
build_inputs
(
self
,
params
:
exp_cfg
.
DataConfig
,
input_context
:
Optional
[
tf
.
distribute
.
InputContext
]
=
None
):
"""Build input dataset."""
if
params
.
tfds_name
:
decoder
=
tfds_factory
.
get_detection_decoder
(
params
.
tfds_name
)
else
:
decoder_cfg
=
params
.
decoder
.
get
()
if
params
.
decoder
.
type
==
'simple_decoder'
:
decoder
=
tf_example_decoder
.
TfExampleDecoder
(
regenerate_source_id
=
decoder_cfg
.
regenerate_source_id
)
elif
params
.
decoder
.
type
==
'label_map_decoder'
:
decoder
=
tf_example_label_map_decoder
.
TfExampleDecoderLabelMap
(
label_map
=
decoder_cfg
.
label_map
,
regenerate_source_id
=
decoder_cfg
.
regenerate_source_id
)
else
:
raise
ValueError
(
'Unknown decoder type: {}!'
.
format
(
params
.
decoder
.
type
))
parser
=
centernet_input
.
CenterNetParser
(
output_height
=
self
.
task_config
.
model
.
input_size
[
0
],
output_width
=
self
.
task_config
.
model
.
input_size
[
1
],
max_num_instances
=
self
.
task_config
.
model
.
max_num_instances
,
bgr_ordering
=
params
.
parser
.
bgr_ordering
,
channel_means
=
params
.
parser
.
channel_means
,
channel_stds
=
params
.
parser
.
channel_stds
,
aug_rand_hflip
=
params
.
parser
.
aug_rand_hflip
,
aug_scale_min
=
params
.
parser
.
aug_scale_min
,
aug_scale_max
=
params
.
parser
.
aug_scale_max
,
aug_rand_hue
=
params
.
parser
.
aug_rand_hue
,
aug_rand_brightness
=
params
.
parser
.
aug_rand_brightness
,
aug_rand_contrast
=
params
.
parser
.
aug_rand_contrast
,
aug_rand_saturation
=
params
.
parser
.
aug_rand_saturation
,
odapi_augmentation
=
params
.
parser
.
odapi_augmentation
,
dtype
=
params
.
dtype
)
reader
=
input_reader
.
InputReader
(
params
,
dataset_fn
=
tf
.
data
.
TFRecordDataset
,
decoder_fn
=
decoder
.
decode
,
parser_fn
=
parser
.
parse_fn
(
params
.
is_training
))
dataset
=
reader
.
read
(
input_context
=
input_context
)
return
dataset
def
build_model
(
self
):
"""get an instance of CenterNet."""
model_config
=
self
.
task_config
.
model
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
None
]
+
model_config
.
input_size
)
l2_weight_decay
=
self
.
task_config
.
weight_decay
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
l2_regularizer
=
(
tf
.
keras
.
regularizers
.
l2
(
l2_weight_decay
/
2.0
)
if
l2_weight_decay
else
None
)
backbone
=
factory
.
build_backbone
(
input_specs
=
input_specs
,
backbone_config
=
model_config
.
backbone
,
norm_activation_config
=
model_config
.
norm_activation
,
l2_regularizer
=
l2_regularizer
)
task_outputs
=
self
.
task_config
.
get_output_length_dict
()
head_config
=
model_config
.
head
head
=
centernet_head
.
CenterNetHead
(
input_specs
=
backbone
.
output_specs
,
task_outputs
=
task_outputs
,
input_levels
=
head_config
.
input_levels
,
heatmap_bias
=
head_config
.
heatmap_bias
)
# output_specs is a dict
backbone_output_spec
=
backbone
.
output_specs
[
head_config
.
input_levels
[
-
1
]]
if
len
(
backbone_output_spec
)
==
4
:
bb_output_height
=
backbone_output_spec
[
1
]
elif
len
(
backbone_output_spec
)
==
3
:
bb_output_height
=
backbone_output_spec
[
0
]
else
:
raise
ValueError
self
.
_net_down_scale
=
int
(
model_config
.
input_size
[
0
]
/
bb_output_height
)
dg_config
=
model_config
.
detection_generator
detect_generator_obj
=
detection_generator
.
CenterNetDetectionGenerator
(
max_detections
=
dg_config
.
max_detections
,
peak_error
=
dg_config
.
peak_error
,
peak_extract_kernel_size
=
dg_config
.
peak_extract_kernel_size
,
class_offset
=
dg_config
.
class_offset
,
net_down_scale
=
self
.
_net_down_scale
,
input_image_dims
=
model_config
.
input_size
[
0
],
use_nms
=
dg_config
.
use_nms
,
nms_pre_thresh
=
dg_config
.
nms_pre_thresh
,
nms_thresh
=
dg_config
.
nms_thresh
)
model
=
centernet_model
.
CenterNetModel
(
backbone
=
backbone
,
head
=
head
,
detection_generator
=
detect_generator_obj
)
return
model
def
initialize
(
self
,
model
:
tf
.
keras
.
Model
):
"""Loading pretrained checkpoint."""
if
not
self
.
task_config
.
init_checkpoint
:
return
ckpt_dir_or_file
=
self
.
task_config
.
init_checkpoint
# Restoring checkpoint.
if
tf
.
io
.
gfile
.
isdir
(
ckpt_dir_or_file
):
ckpt_dir_or_file
=
tf
.
train
.
latest_checkpoint
(
ckpt_dir_or_file
)
if
self
.
task_config
.
init_checkpoint_modules
==
'all'
:
ckpt
=
tf
.
train
.
Checkpoint
(
**
model
.
checkpoint_items
)
status
=
ckpt
.
restore
(
ckpt_dir_or_file
)
status
.
assert_consumed
()
elif
self
.
task_config
.
init_checkpoint_modules
==
'backbone'
:
ckpt
=
tf
.
train
.
Checkpoint
(
backbone
=
model
.
backbone
)
status
=
ckpt
.
restore
(
ckpt_dir_or_file
)
status
.
expect_partial
().
assert_existing_objects_matched
()
else
:
raise
ValueError
(
"Only 'all' or 'backbone' can be used to initialize the model."
)
logging
.
info
(
'Finished loading pretrained checkpoint from %s'
,
ckpt_dir_or_file
)
def
build_losses
(
self
,
outputs
,
labels
,
aux_losses
=
None
):
"""Build losses."""
input_size
=
self
.
task_config
.
model
.
input_size
[
0
:
2
]
output_size
=
outputs
[
'ct_heatmaps'
][
0
].
get_shape
().
as_list
()[
1
:
3
]
gt_label
=
tf
.
map_fn
(
# pylint: disable=g-long-lambda
fn
=
lambda
x
:
target_assigner
.
assign_centernet_targets
(
labels
=
x
,
input_size
=
input_size
,
output_size
=
output_size
,
num_classes
=
self
.
task_config
.
model
.
num_classes
,
max_num_instances
=
self
.
task_config
.
model
.
max_num_instances
,
gaussian_iou
=
self
.
task_config
.
losses
.
gaussian_iou
,
class_offset
=
self
.
task_config
.
losses
.
class_offset
),
elems
=
labels
,
fn_output_signature
=
{
'ct_heatmaps'
:
tf
.
TensorSpec
(
shape
=
[
output_size
[
0
],
output_size
[
1
],
self
.
task_config
.
model
.
num_classes
],
dtype
=
tf
.
float32
),
'ct_offset'
:
tf
.
TensorSpec
(
shape
=
[
self
.
task_config
.
model
.
max_num_instances
,
2
],
dtype
=
tf
.
float32
),
'size'
:
tf
.
TensorSpec
(
shape
=
[
self
.
task_config
.
model
.
max_num_instances
,
2
],
dtype
=
tf
.
float32
),
'box_mask'
:
tf
.
TensorSpec
(
shape
=
[
self
.
task_config
.
model
.
max_num_instances
],
dtype
=
tf
.
int32
),
'box_indices'
:
tf
.
TensorSpec
(
shape
=
[
self
.
task_config
.
model
.
max_num_instances
,
2
],
dtype
=
tf
.
int32
),
}
)
losses
=
{}
# Create loss functions
object_center_loss_fn
=
centernet_losses
.
PenaltyReducedLogisticFocalLoss
()
localization_loss_fn
=
centernet_losses
.
L1LocalizationLoss
()
# Set up box indices so that they have a batch element as well
box_indices
=
loss_ops
.
add_batch_to_indices
(
gt_label
[
'box_indices'
])
box_mask
=
tf
.
cast
(
gt_label
[
'box_mask'
],
dtype
=
tf
.
float32
)
num_boxes
=
tf
.
cast
(
loss_ops
.
get_num_instances_from_weights
(
gt_label
[
'box_mask'
]),
dtype
=
tf
.
float32
)
# Calculate center heatmap loss
output_unpad_image_shapes
=
tf
.
math
.
ceil
(
tf
.
cast
(
labels
[
'unpad_image_shapes'
],
tf
.
float32
)
/
self
.
_net_down_scale
)
valid_anchor_weights
=
loss_ops
.
get_valid_anchor_weights_in_flattened_image
(
output_unpad_image_shapes
,
output_size
[
0
],
output_size
[
1
])
valid_anchor_weights
=
tf
.
expand_dims
(
valid_anchor_weights
,
2
)
pred_ct_heatmap_list
=
outputs
[
'ct_heatmaps'
]
true_flattened_ct_heatmap
=
loss_ops
.
flatten_spatial_dimensions
(
gt_label
[
'ct_heatmaps'
])
true_flattened_ct_heatmap
=
tf
.
cast
(
true_flattened_ct_heatmap
,
tf
.
float32
)
total_center_loss
=
0.0
for
ct_heatmap
in
pred_ct_heatmap_list
:
pred_flattened_ct_heatmap
=
loss_ops
.
flatten_spatial_dimensions
(
ct_heatmap
)
pred_flattened_ct_heatmap
=
tf
.
cast
(
pred_flattened_ct_heatmap
,
tf
.
float32
)
total_center_loss
+=
object_center_loss_fn
(
target_tensor
=
true_flattened_ct_heatmap
,
prediction_tensor
=
pred_flattened_ct_heatmap
,
weights
=
valid_anchor_weights
)
center_loss
=
tf
.
reduce_sum
(
total_center_loss
)
/
float
(
len
(
pred_ct_heatmap_list
)
*
num_boxes
)
losses
[
'ct_loss'
]
=
center_loss
# Calculate scale loss
pred_scale_list
=
outputs
[
'ct_size'
]
true_scale
=
tf
.
cast
(
gt_label
[
'size'
],
tf
.
float32
)
total_scale_loss
=
0.0
for
scale_map
in
pred_scale_list
:
pred_scale
=
loss_ops
.
get_batch_predictions_from_indices
(
scale_map
,
box_indices
)
pred_scale
=
tf
.
cast
(
pred_scale
,
tf
.
float32
)
# Only apply loss for boxes that appear in the ground truth
total_scale_loss
+=
tf
.
reduce_sum
(
localization_loss_fn
(
target_tensor
=
true_scale
,
prediction_tensor
=
pred_scale
),
axis
=-
1
)
*
box_mask
scale_loss
=
tf
.
reduce_sum
(
total_scale_loss
)
/
float
(
len
(
pred_scale_list
)
*
num_boxes
)
losses
[
'scale_loss'
]
=
scale_loss
# Calculate offset loss
pred_offset_list
=
outputs
[
'ct_offset'
]
true_offset
=
tf
.
cast
(
gt_label
[
'ct_offset'
],
tf
.
float32
)
total_offset_loss
=
0.0
for
offset_map
in
pred_offset_list
:
pred_offset
=
loss_ops
.
get_batch_predictions_from_indices
(
offset_map
,
box_indices
)
pred_offset
=
tf
.
cast
(
pred_offset
,
tf
.
float32
)
# Only apply loss for boxes that appear in the ground truth
total_offset_loss
+=
tf
.
reduce_sum
(
localization_loss_fn
(
target_tensor
=
true_offset
,
prediction_tensor
=
pred_offset
),
axis
=-
1
)
*
box_mask
offset_loss
=
tf
.
reduce_sum
(
total_offset_loss
)
/
float
(
len
(
pred_offset_list
)
*
num_boxes
)
losses
[
'ct_offset_loss'
]
=
offset_loss
# Aggregate and finalize loss
loss_weights
=
self
.
task_config
.
losses
.
detection
total_loss
=
(
loss_weights
.
object_center_weight
*
center_loss
+
loss_weights
.
scale_weight
*
scale_loss
+
loss_weights
.
offset_weight
*
offset_loss
)
if
aux_losses
:
total_loss
+=
tf
.
add_n
(
aux_losses
)
losses
[
'total_loss'
]
=
total_loss
return
losses
  def build_metrics(self, training=True):
    metrics = []
    metric_names = ['total_loss', 'ct_loss', 'scale_loss', 'ct_offset_loss']
    for name in metric_names:
      metrics.append(tf.keras.metrics.Mean(name, dtype=tf.float32))

    if not training:
      if (self.task_config.validation_data.tfds_name and
          self.task_config.annotation_file):
        raise ValueError(
            "Can't evaluate using annotation file when TFDS is used.")
      self.coco_metric = coco_evaluator.COCOEvaluator(
          annotation_file=self.task_config.annotation_file,
          include_mask=False,
          per_category_metrics=self.task_config.per_category_metrics)

    return metrics
  def train_step(self,
                 inputs: Tuple[Any, Any],
                 model: tf.keras.Model,
                 optimizer: tf.keras.optimizers.Optimizer,
                 metrics: Optional[List[Any]] = None):
    """Does forward and backward.

    Args:
      inputs: a tuple of (features, labels) input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs
    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync

    with tf.GradientTape() as tape:
      outputs = model(features, training=True)
      # Casting the output layer to float32 is necessary when mixed_precision
      # is mixed_float16 or mixed_bfloat16 to ensure the output is cast to
      # float32.
      outputs = tf.nest.map_structure(
          lambda x: tf.cast(x, tf.float32), outputs)

      losses = self.build_losses(outputs['raw_output'], labels)
      scaled_loss = losses['total_loss'] / num_replicas

      # For a mixed_precision policy, when LossScaleOptimizer is used, the
      # loss is scaled for numerical stability.
      if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    # Compute the gradient.
    tvars = model.trainable_variables
    gradients = tape.gradient(scaled_loss, tvars)

    # Unscale the gradients if the scaled loss was used.
    if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
      gradients = optimizer.get_unscaled_gradients(gradients)

    if self.task_config.gradient_clip_norm > 0.0:
      gradients, _ = tf.clip_by_global_norm(
          gradients, self.task_config.gradient_clip_norm)

    optimizer.apply_gradients(list(zip(gradients, tvars)))

    logs = {self.loss: losses['total_loss']}

    if metrics:
      for m in metrics:
        m.update_state(losses[m.name])
        logs.update({m.name: m.result()})

    return logs
  def validation_step(self,
                      inputs: Tuple[Any, Any],
                      model: tf.keras.Model,
                      metrics: Optional[List[Any]] = None):
    """Validation step.

    Args:
      inputs: a tuple of (features, labels) input tensors.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs

    outputs = model(features, training=False)
    outputs = tf.nest.map_structure(
        lambda x: tf.cast(x, tf.float32), outputs)
    losses = self.build_losses(outputs['raw_output'], labels)

    logs = {self.loss: losses['total_loss']}

    coco_model_outputs = {
        'detection_boxes': outputs['boxes'],
        'detection_scores': outputs['confidence'],
        'detection_classes': outputs['classes'],
        'num_detections': outputs['num_detections'],
        'source_id': labels['groundtruths']['source_id'],
        'image_info': labels['image_info']
    }

    logs.update(
        {self.coco_metric.name: (labels['groundtruths'], coco_model_outputs)})

    if metrics:
      for m in metrics:
        m.update_state(losses[m.name])
        logs.update({m.name: m.result()})

    return logs
  def aggregate_logs(self, state=None, step_outputs=None):
    if state is None:
      self.coco_metric.reset_states()
      state = self.coco_metric
    self.coco_metric.update_state(
        step_outputs[self.coco_metric.name][0],
        step_outputs[self.coco_metric.name][1])
    return state

  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
    return self.coco_metric.result()
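# Illustrative sketch of how build_losses normalizes and combines the
# per-head terms above. Every number and weight below is made up for
# illustration; the real weights come from task_config.losses.detection.
num_heads = 2     # e.g. two hourglass stacks, each with its own predictions
num_boxes = 8.0   # number of ground-truth instances in the batch
summed_center, summed_scale, summed_offset = 64.0, 12.0, 6.0  # per-head sums

center_loss = summed_center / (num_heads * num_boxes)   # 4.0
scale_loss = summed_scale / (num_heads * num_boxes)     # 0.75
offset_loss = summed_offset / (num_heads * num_boxes)   # 0.375
total_loss = 1.0 * center_loss + 0.1 * scale_loss + 1.0 * offset_loss  # 4.45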
official/vision/beta/projects/centernet/train.py
0 → 100644
View file @
460890ed
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorFlow Model Garden Vision Centernet trainer."""
from absl import app
from absl import flags
import gin

from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
from official.vision.beta.projects.centernet.common import registry_imports  # pylint: disable=unused-import

FLAGS = flags.FLAGS


def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  params = train_utils.parse_configuration(FLAGS)
  model_dir = FLAGS.model_dir
  if 'train' in FLAGS.mode:
    # Pure eval modes do not output yaml files. Otherwise continuous eval job
    # may race against the train job for writing the same file.
    train_utils.serialize_config(params, model_dir)

  # Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can have a significant impact on model speed by utilizing float16 in the
  # case of GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect
  # only when dtype is float16.
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu)
  with distribution_strategy.scope():
    task = task_factory.get_task(params.task, logging_dir=model_dir)

  train_lib.run_experiment(
      distribution_strategy=distribution_strategy,
      task=task,
      mode=FLAGS.mode,
      params=params,
      model_dir=model_dir)

  train_utils.save_gin_config(FLAGS.mode, model_dir)


if __name__ == '__main__':
  tfm_flags.define_flags()
  app.run(main)
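# Example invocation (illustrative): the flag names are defined by
# tfm_flags.define_flags() in official/common/flags.py, and the experiment
# name below is an assumption, not a value taken from this file.
#
#   python3 -m official.vision.beta.projects.centernet.train \
#       --experiment=centernet_hourglass_coco \
#       --mode=train_and_eval \
#       --model_dir=/tmp/centernet \
#       --config_file=/path/to/params.yaml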
official/vision/beta/projects/centernet/utils/checkpoints/__init__.py
0 → 100644
View file @
460890ed
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
official/vision/beta/projects/centernet/utils/checkpoints/config_classes.py
0 → 100644
View file @
460890ed
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Layer config for parsing ODAPI checkpoint.
This file contains the layers (Config objects) that are used for parsing the
ODAPI checkpoint weights for CenterNet.
Currently, the parser is incomplete and has only been tested on
CenterNet Hourglass-104 512x512.
"""
import abc
import dataclasses
from typing import Dict, Optional

import numpy as np
import tensorflow as tf


class Config(abc.ABC):
  """Base config class."""

  def get_weights(self):
    """Generates the weights needed to be loaded into the layer."""
    raise NotImplementedError

  def load_weights(self, layer: tf.keras.layers.Layer) -> int:
    """Assigns weights to a layer.

    Given a layer, this function retrieves the weights for that layer in an
    appropriate format and order, and loads them into the layer. Additionally,
    the number of weights loaded is returned. If the weights are in an
    incorrect format, a ValueError will be raised by set_weights().

    Args:
      layer: A `tf.keras.layers.Layer`.

    Returns:
      The number of weights loaded into the layer.
    """
    weights = self.get_weights()
    layer.set_weights(weights)

    n_weights = 0
    for w in weights:
      n_weights += w.size
    return n_weights
@dataclasses.dataclass
class Conv2DBNCFG(Config):
  """Config class for Conv2DBN block."""
  weights_dict: Optional[Dict[str, np.ndarray]] = dataclasses.field(
      repr=False, default=None)

  weights: Optional[np.ndarray] = dataclasses.field(repr=False, default=None)
  beta: Optional[np.ndarray] = dataclasses.field(repr=False, default=None)
  gamma: Optional[np.ndarray] = dataclasses.field(repr=False, default=None)
  moving_mean: Optional[np.ndarray] = dataclasses.field(
      repr=False, default=None)
  moving_variance: Optional[np.ndarray] = dataclasses.field(
      repr=False, default=None)

  def __post_init__(self):
    conv_weights_dict = self.weights_dict['conv']
    norm_weights_dict = self.weights_dict['norm']

    self.weights = conv_weights_dict['kernel']

    self.beta = norm_weights_dict['beta']
    self.gamma = norm_weights_dict['gamma']
    self.moving_mean = norm_weights_dict['moving_mean']
    self.moving_variance = norm_weights_dict['moving_variance']

  def get_weights(self):
    return [
        self.weights,
        self.gamma,
        self.beta,
        self.moving_mean,
        self.moving_variance
    ]
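# Illustrative sketch: building a Conv2DBNCFG from a nested weights dict. The
# shapes below are hypothetical; only the 'conv'/'norm' key layout matches
# __post_init__ above.
_fake_weights = {
    'conv': {'kernel': np.zeros((3, 3, 3, 16), np.float32)},
    'norm': {
        'beta': np.zeros(16, np.float32),
        'gamma': np.ones(16, np.float32),
        'moving_mean': np.zeros(16, np.float32),
        'moving_variance': np.ones(16, np.float32),
    },
}
_cfg = Conv2DBNCFG(weights_dict=_fake_weights)
# get_weights() orders the arrays as [kernel, gamma, beta, moving_mean,
# moving_variance], the order set_weights() expects for a conv + BN block.
assert len(_cfg.get_weights()) == 5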
@dataclasses.dataclass
class ResidualBlockCFG(Config):
  """Config class for Residual block."""
  weights_dict: Optional[Dict[str, np.ndarray]] = dataclasses.field(
      repr=False, default=None)

  skip_weights: Optional[np.ndarray] = dataclasses.field(
      repr=False, default=None)
  skip_beta: Optional[np.ndarray] = dataclasses.field(
      repr=False, default=None)
  skip_gamma: Optional[np.ndarray] = dataclasses.field(
      repr=False, default=None)
  skip_moving_mean: Optional[np.ndarray] = dataclasses.field(
      repr=False, default=None)
  skip_moving_variance: Optional[np.ndarray] = dataclasses.field(
      repr=False, default=None)

  conv_weights: Optional[np.ndarray] = dataclasses.field(
      repr=False, default=None)
  norm_beta: Optional[np.ndarray] = dataclasses.field(
      repr=False, default=None)
  norm_gamma: Optional[np.ndarray] = dataclasses.field(
      repr=False, default=None)
  norm_moving_mean: Optional[np.ndarray] = dataclasses.field(
      repr=False, default=None)
  norm_moving_variance: Optional[np.ndarray] = dataclasses.field(
      repr=False, default=None)

  conv_block_weights: Optional[np.ndarray] = dataclasses.field(
      repr=False, default=None)
  conv_block_beta: Optional[np.ndarray] = dataclasses.field(
      repr=False, default=None)
  conv_block_gamma: Optional[np.ndarray] = dataclasses.field(
      repr=False, default=None)
  conv_block_moving_mean: Optional[np.ndarray] = dataclasses.field(
      repr=False, default=None)
  conv_block_moving_variance: Optional[np.ndarray] = dataclasses.field(
      repr=False, default=None)

  def __post_init__(self):
    conv_weights_dict = self.weights_dict['conv']
    norm_weights_dict = self.weights_dict['norm']
    conv_block_weights_dict = self.weights_dict['conv_block']

    if 'skip' in self.weights_dict:
      skip_weights_dict = self.weights_dict['skip']
      self.skip_weights = skip_weights_dict['conv']['kernel']
      self.skip_beta = skip_weights_dict['norm']['beta']
      self.skip_gamma = skip_weights_dict['norm']['gamma']
      self.skip_moving_mean = skip_weights_dict['norm']['moving_mean']
      self.skip_moving_variance = skip_weights_dict['norm']['moving_variance']

    self.conv_weights = conv_weights_dict['kernel']
    self.norm_beta = norm_weights_dict['beta']
    self.norm_gamma = norm_weights_dict['gamma']
    self.norm_moving_mean = norm_weights_dict['moving_mean']
    self.norm_moving_variance = norm_weights_dict['moving_variance']

    self.conv_block_weights = conv_block_weights_dict['conv']['kernel']
    self.conv_block_beta = conv_block_weights_dict['norm']['beta']
    self.conv_block_gamma = conv_block_weights_dict['norm']['gamma']
    self.conv_block_moving_mean = conv_block_weights_dict['norm'][
        'moving_mean']
    self.conv_block_moving_variance = conv_block_weights_dict['norm'][
        'moving_variance']

  def get_weights(self):
    weights = [
        self.skip_weights,
        self.skip_gamma,
        self.skip_beta,
        self.conv_block_weights,
        self.conv_block_gamma,
        self.conv_block_beta,
        self.conv_weights,
        self.norm_gamma,
        self.norm_beta,
        self.skip_moving_mean,
        self.skip_moving_variance,
        self.conv_block_moving_mean,
        self.conv_block_moving_variance,
        self.norm_moving_mean,
        self.norm_moving_variance,
    ]
    weights = [x for x in weights if x is not None]
    return weights
@dataclasses.dataclass
class HeadConvCFG(Config):
  """Config class for HeadConv block."""
  weights_dict: Optional[Dict[str, np.ndarray]] = dataclasses.field(
      repr=False, default=None)

  conv_1_weights: Optional[np.ndarray] = dataclasses.field(
      repr=False, default=None)
  conv_1_bias: Optional[np.ndarray] = dataclasses.field(
      repr=False, default=None)
  conv_2_weights: Optional[np.ndarray] = dataclasses.field(
      repr=False, default=None)
  conv_2_bias: Optional[np.ndarray] = dataclasses.field(
      repr=False, default=None)

  def __post_init__(self):
    conv_1_weights_dict = self.weights_dict['layer_with_weights-0']
    conv_2_weights_dict = self.weights_dict['layer_with_weights-1']

    self.conv_1_weights = conv_1_weights_dict['kernel']
    self.conv_1_bias = conv_1_weights_dict['bias']
    self.conv_2_weights = conv_2_weights_dict['kernel']
    self.conv_2_bias = conv_2_weights_dict['bias']

  def get_weights(self):
    return [
        self.conv_1_weights,
        self.conv_1_bias,
        self.conv_2_weights,
        self.conv_2_bias
    ]
@dataclasses.dataclass
class HourglassCFG(Config):
  """Config class for Hourglass block."""
  weights_dict: Optional[Dict[str, np.ndarray]] = dataclasses.field(
      repr=False, default=None)
  is_last_stage: bool = dataclasses.field(repr=False, default=None)

  def __post_init__(self):
    self.is_last_stage = False if 'inner_block' in self.weights_dict else True

  def get_weights(self):
    """It is not used in this class."""
    return None

  def generate_block_weights(self, weights_dict):
    """Converts a weights dict to the blocks structure."""
    reps = len(weights_dict.keys())
    weights = []
    n_weights = 0

    for i in range(reps):
      res_config = ResidualBlockCFG(weights_dict=weights_dict[str(i)])
      res_weights = res_config.get_weights()
      weights += res_weights

      for w in res_weights:
        n_weights += w.size

    return weights, n_weights

  def load_block_weights(self, layer, weight_dict):
    block_weights, n_weights = self.generate_block_weights(weight_dict)
    layer.set_weights(block_weights)
    return n_weights

  def load_weights(self, layer):
    n_weights = 0

    if not self.is_last_stage:
      enc_dec_layers = [
          layer.submodules[0],
          layer.submodules[1],
          layer.submodules[3]
      ]
      enc_dec_weight_dicts = [
          self.weights_dict['encoder_block1'],
          self.weights_dict['encoder_block2'],
          self.weights_dict['decoder_block']
      ]

      for l, weights_dict in zip(enc_dec_layers, enc_dec_weight_dicts):
        n_weights += self.load_block_weights(l, weights_dict)

      if len(self.weights_dict['inner_block']) == 1:
        # still in an outer hourglass
        inner_weights_dict = self.weights_dict['inner_block']['0']
      else:
        # inner residual block chain
        inner_weights_dict = self.weights_dict['inner_block']

      inner_hg_layer = layer.submodules[2]
      inner_hg_cfg = type(self)(weights_dict=inner_weights_dict)
      n_weights += inner_hg_cfg.load_weights(inner_hg_layer)

    else:
      inner_layer = layer.submodules[0]
      n_weights += self.load_block_weights(inner_layer, self.weights_dict)

    return n_weights
official/vision/beta/projects/centernet/utils/checkpoints/config_data.py
0 → 100644
View file @
460890ed
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configurations for loading checkpoints."""
import dataclasses
from typing import Dict, Optional

import numpy as np

from official.vision.beta.projects.centernet.utils.checkpoints import config_classes

Conv2DBNCFG = config_classes.Conv2DBNCFG
HeadConvCFG = config_classes.HeadConvCFG
ResidualBlockCFG = config_classes.ResidualBlockCFG
HourglassCFG = config_classes.HourglassCFG


@dataclasses.dataclass
class BackboneConfigData:
  """Backbone Config."""
  weights_dict: Optional[Dict[str, np.ndarray]] = dataclasses.field(
      repr=False, default=None)

  def get_cfg_list(self, name):
    """Get list of block configs for the module."""
    if name == 'hourglass104_512':
      return [
          # Downsampling Layers
          Conv2DBNCFG(
              weights_dict=self.weights_dict['downsample_input']['conv_block']),
          ResidualBlockCFG(
              weights_dict=self.weights_dict['downsample_input']
              ['residual_block']),
          # Hourglass
          HourglassCFG(
              weights_dict=self.weights_dict['hourglass_network']['0']),
          Conv2DBNCFG(
              weights_dict=self.weights_dict['output_conv']['0']),
          # Intermediate
          Conv2DBNCFG(
              weights_dict=self.weights_dict['intermediate_conv1']['0']),
          Conv2DBNCFG(
              weights_dict=self.weights_dict['intermediate_conv2']['0']),
          ResidualBlockCFG(
              weights_dict=self.weights_dict['intermediate_residual']['0']),
          # Hourglass
          HourglassCFG(
              weights_dict=self.weights_dict['hourglass_network']['1']),
          Conv2DBNCFG(
              weights_dict=self.weights_dict['output_conv']['1']),
      ]
    elif name == 'extremenet':
      return [
          # Downsampling Layers
          Conv2DBNCFG(
              weights_dict=self.weights_dict['downsample_input']['conv_block']),
          ResidualBlockCFG(
              weights_dict=self.weights_dict['downsample_input']
              ['residual_block']),
          # Hourglass
          HourglassCFG(
              weights_dict=self.weights_dict['hourglass_network']['0']),
          Conv2DBNCFG(
              weights_dict=self.weights_dict['output_conv']['0']),
          # Intermediate
          Conv2DBNCFG(
              weights_dict=self.weights_dict['intermediate_conv1']['0']),
          Conv2DBNCFG(
              weights_dict=self.weights_dict['intermediate_conv2']['0']),
          ResidualBlockCFG(
              weights_dict=self.weights_dict['intermediate_residual']['0']),
          # Hourglass
          HourglassCFG(
              weights_dict=self.weights_dict['hourglass_network']['1']),
          Conv2DBNCFG(
              weights_dict=self.weights_dict['output_conv']['1']),
      ]


@dataclasses.dataclass
class HeadConfigData:
  """Head Config."""
  weights_dict: Optional[Dict[str, np.ndarray]] = dataclasses.field(
      repr=False, default=None)

  def get_cfg_list(self, name):
    if name == 'detection_2d':
      return [
          HeadConvCFG(weights_dict=self.weights_dict['object_center']['0']),
          HeadConvCFG(weights_dict=self.weights_dict['object_center']['1']),
          HeadConvCFG(weights_dict=self.weights_dict['box.Soffset']['0']),
          HeadConvCFG(weights_dict=self.weights_dict['box.Soffset']['1']),
          HeadConvCFG(weights_dict=self.weights_dict['box.Sscale']['0']),
          HeadConvCFG(weights_dict=self.weights_dict['box.Sscale']['1'])
      ]
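# Illustrative usage: fetching block configs from a nested weights dictionary
# (as produced by read_checkpoints.get_ckpt_weights_as_dict). The variables
# backbone_weights and head_weights are placeholders.
#
#   backbone_cfgs = BackboneConfigData(
#       weights_dict=backbone_weights).get_cfg_list('hourglass104_512')
#   head_cfgs = HeadConfigData(
#       weights_dict=head_weights).get_cfg_list('detection_2d')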
official/vision/beta/projects/centernet/utils/checkpoints/load_weights.py
0 → 100644
View file @
460890ed
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions used to load the ODAPI CenterNet checkpoint."""
from official.vision.beta.modeling.backbones import mobilenet
from official.vision.beta.modeling.layers import nn_blocks
from official.vision.beta.projects.centernet.modeling.layers import cn_nn_blocks
from official.vision.beta.projects.centernet.utils.checkpoints import config_classes
from official.vision.beta.projects.centernet.utils.checkpoints import config_data

Conv2DBNCFG = config_classes.Conv2DBNCFG
HeadConvCFG = config_classes.HeadConvCFG
ResidualBlockCFG = config_classes.ResidualBlockCFG
HourglassCFG = config_classes.HourglassCFG

BackboneConfigData = config_data.BackboneConfigData
HeadConfigData = config_data.HeadConfigData


def get_backbone_layer_cfgs(weights_dict, backbone_name):
  """Fetches the config classes for the backbone.

  This function generates a list of config classes corresponding to
  each building block in the backbone.

  Args:
    weights_dict: Dictionary that stores the backbone model weights.
    backbone_name: String, indicating the desired backbone configuration.

  Returns:
    A list containing the config classes of the backbone building blocks.
  """
  print("Fetching backbone config classes for {}\n".format(backbone_name))
  cfgs = BackboneConfigData(weights_dict=weights_dict).get_cfg_list(
      backbone_name)
  return cfgs


def load_weights_backbone(backbone, weights_dict, backbone_name):
  """Loads the weights defined in the weights_dict into the backbone.

  This function loads the backbone weights by first fetching the necessary
  config classes for the backbone, then loading them in one by one for
  each layer that has weights associated with it.

  Args:
    backbone: keras.Model backbone.
    weights_dict: Dictionary that stores the backbone model weights.
    backbone_name: String, indicating the desired backbone configuration.

  Returns:
    The number of weights loaded in.
  """
  print("Loading backbone weights\n")
  backbone_layers = backbone.layers
  cfgs = get_backbone_layer_cfgs(weights_dict, backbone_name)
  n_weights_total = 0

  cfg = cfgs.pop(0)
  for i in range(len(backbone_layers)):
    layer = backbone_layers[i]
    if isinstance(layer, (mobilenet.Conv2DBNBlock,
                          cn_nn_blocks.HourglassBlock,
                          nn_blocks.ResidualBlock)):
      n_weights = cfg.load_weights(layer)
      print("Loading weights for: {}, weights loaded: {}".format(
          cfg, n_weights))
      n_weights_total += n_weights

      # pylint: disable=g-explicit-length-test
      if len(cfgs) == 0:
        print("{} Weights have been loaded for {} / {} layers\n".format(
            n_weights_total, i + 1, len(backbone_layers)))
        return n_weights_total
      cfg = cfgs.pop(0)

  return n_weights_total


def get_head_layer_cfgs(weights_dict, head_name):
  """Fetches the config classes for the head.

  This function generates a list of config classes corresponding to
  each building block in the head.

  Args:
    weights_dict: Dictionary that stores the head model weights.
    head_name: String, indicating the desired head configuration.

  Returns:
    A list containing the config classes of the head building blocks.
  """
  print("Fetching head config classes for {}\n".format(head_name))
  cfgs = HeadConfigData(weights_dict=weights_dict).get_cfg_list(head_name)
  return cfgs


def load_weights_head(head, weights_dict, head_name):
  """Loads the weights defined in the weights_dict into the head.

  This function loads the head weights by first fetching the necessary
  config classes for the head, then loading them in one by one for
  each layer that has weights associated with it.

  Args:
    head: keras.Model head.
    weights_dict: Dictionary that stores the head model weights.
    head_name: String, indicating the desired head configuration.

  Returns:
    The number of weights loaded in.
  """
  print("Loading head weights\n")
  head_layers = head.layers
  cfgs = get_head_layer_cfgs(weights_dict, head_name)
  n_weights_total = 0

  cfg = cfgs.pop(0)
  for i in range(len(head_layers)):
    layer = head_layers[i]
    if isinstance(layer, cn_nn_blocks.CenterNetHeadConv):
      n_weights = cfg.load_weights(layer)
      print("Loading weights for: {}, weights loaded: {}".format(
          cfg, n_weights))
      n_weights_total += n_weights

      # pylint: disable=g-explicit-length-test
      if len(cfgs) == 0:
        print("{} Weights have been loaded for {} / {} layers\n".format(
            n_weights_total, i + 1, len(head_layers)))
        return n_weights_total
      cfg = cfgs.pop(0)

  return n_weights_total


def load_weights_model(model, weights_dict, backbone_name, head_name):
  """Loads weights into the model.

  Args:
    model: keras.Model to load weights into.
    weights_dict: Dictionary that stores the weights of the model.
    backbone_name: String, indicating the desired backbone configuration.
    head_name: String, indicating the desired head configuration.

  Returns:
    The model with the loaded weights.
  """
  print("Loading model weights\n")
  n_weights = 0
  if backbone_name:
    n_weights += load_weights_backbone(
        model.backbone,
        weights_dict["model"]["_feature_extractor"]["_network"],
        backbone_name)
  if head_name:
    n_weights += load_weights_head(
        model.head,
        weights_dict["model"]["_prediction_head_dict"],
        head_name)
  print("Successfully loaded {} model weights.\n".format(n_weights))
  return model
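# Illustrative flow, mirroring what the TF2 checkpoint converter does later in
# this commit; the model and checkpoint path are placeholders.
#
#   from official.vision.beta.projects.centernet.utils.checkpoints import (
#       read_checkpoints)
#
#   weights_dict, _ = read_checkpoints.get_ckpt_weights_as_dict(
#       '/path/to/odapi_checkpoint')
#   model = load_weights_model(
#       model=model,  # an already-built CenterNetModel
#       weights_dict=weights_dict,
#       backbone_name='hourglass104_512',
#       head_name='detection_2d')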
official/vision/beta/projects/centernet/utils/checkpoints/read_checkpoints.py
0 → 100644
View file @
460890ed
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions used to convert a TF checkpoint into a dictionary."""
import numpy as np
import tensorflow as tf


def update_weights_dict(weights_dict, variable_key, value):
  """Inserts a weight value into a weights dictionary.

  This function inserts a weight value into a weights_dict based on the
  variable key. It is designed to organize TF checkpoint weights by
  submodule.

  Args:
    weights_dict: Dictionary to store weights.
    variable_key: String, name of the variable associated with the value.
    value: An ndarray that stores the weights associated with the variable
      key.
  """
  current_dict = weights_dict
  variable_key_list = variable_key.split("/")

  key = variable_key_list.pop(0)
  # pylint: disable=g-explicit-length-test
  while len(variable_key_list):
    if variable_key_list[0] == ".ATTRIBUTES":
      current_dict[key] = value
      return

    if key not in current_dict.keys():
      current_dict[key] = {}
    current_dict = current_dict[key]
    key = variable_key_list.pop(0)
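# Illustrative sketch: every segment before the '.ATTRIBUTES' marker becomes a
# nesting level, and the value is stored under the last real segment. The key
# below is hypothetical but follows the object-based checkpoint naming scheme.
_weights = {}
update_weights_dict(
    _weights,
    'model/backbone/conv/kernel/.ATTRIBUTES/VARIABLE_VALUE',
    np.zeros((3, 3, 3, 16)))
# _weights is now {'model': {'backbone': {'conv': {'kernel': <ndarray>}}}}.
assert list(_weights) == ['model']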
def get_ckpt_weights_as_dict(ckpt_path):
  """Converts a TF checkpoint into a nested dictionary of weights.

  Args:
    ckpt_path: String, indicating the filepath of the TF checkpoint.

  Returns:
    A dictionary where the checkpoint weights are stored, and the number of
    weights read.
  """
  print("\nConverting model checkpoint from {} to weights dictionary\n".format(
      ckpt_path))
  reader = tf.train.load_checkpoint(ckpt_path)
  shape_from_key = reader.get_variable_to_shape_map()
  # dtype_from_key = reader.get_variable_to_dtype_map()
  variable_keys = shape_from_key.keys()
  weights_dict = {}

  n_read = 0
  for key in variable_keys:
    # shape = shape_from_key[key]
    # dtype = dtype_from_key[key]
    value = reader.get_tensor(key)
    n_read += tf.size(value)
    update_weights_dict(weights_dict, key, value)

  print("Successfully read {} checkpoint weights\n".format(n_read))
  return weights_dict, n_read
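# Illustrative usage (the paths are placeholders); write_dict_as_tree, defined
# below, is handy for inspecting the resulting nested structure.
#
#   weights_dict, n_read = get_ckpt_weights_as_dict('/path/to/ckpt')
#   write_dict_as_tree(weights_dict, '/tmp/ckpt_tree.txt')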
def write_dict_as_tree(dictionary, filename, spaces=0):
  """Writes nested dictionary keys to a file.

  Given a dictionary that contains nested dictionaries, this function
  writes the names of the keys recursively to a specified file as a tree.

  Args:
    dictionary: Desired dictionary to write to a file.
    filename: String, name of the file to write the dictionary to.
    spaces: Optional; number of spaces to insert before writing the
      dictionary key names.
  """
  if isinstance(dictionary, dict):
    mode = "w" if spaces == 0 else "a"
    for key in dictionary.keys():
      with open(filename, mode) as fp:
        fp.write(" " * spaces + key + "\n")
      mode = "a"
      write_dict_as_tree(dictionary[key], filename, spaces + 2)


def print_layer_weights_and_shape(layer):
  """Prints variable information corresponding to a Keras layer.

  This function prints the name of each variable (trainable and
  non-trainable) in a Keras layer, along with the shape of its associated
  weights.

  Args:
    layer: A `tf.keras.layers.Layer` object.
  """
  weights = layer.get_weights()
  variables = layer.variables

  for i in range(len(weights)):
    tf.print(np.shape(weights[i]), variables[i].name)
official/vision/beta/projects/centernet/utils/tf2_centernet_checkpoint_converter.py
0 → 100644
View file @
460890ed
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A converter from a tf1 OD API checkpoint to a tf2 checkpoint."""
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf

from official.vision.beta.modeling.backbones import factory
from official.vision.beta.projects.centernet.common import registry_imports  # pylint: disable=unused-import
from official.vision.beta.projects.centernet.configs import backbones
from official.vision.beta.projects.centernet.configs import centernet
from official.vision.beta.projects.centernet.modeling import centernet_model
from official.vision.beta.projects.centernet.modeling.heads import centernet_head
from official.vision.beta.projects.centernet.modeling.layers import detection_generator
from official.vision.beta.projects.centernet.utils.checkpoints import load_weights
from official.vision.beta.projects.centernet.utils.checkpoints import read_checkpoints

FLAGS = flags.FLAGS

flags.DEFINE_string("checkpoint_to_convert", None,
                    "Initial checkpoint from a pretrained model.")
flags.DEFINE_string("checkpoint_backbone_name", "hourglass104_512",
                    "Indicate the desired backbone configuration.")
flags.DEFINE_string("checkpoint_head_name", "detection_2d",
                    "Indicate the desired head configuration.")
flags.DEFINE_string("converted_checkpoint_path", None,
                    "Output path of converted checkpoint.")
flags.DEFINE_integer("hourglass_id", 52, "Model id of hourglass backbone.")
flags.DEFINE_integer("num_hourglasses", 2,
                     "Number of hourglass blocks in backbone.")


def _create_centernet_model(model_id: int = 52,
                            num_hourglasses: int = 2
                            ) -> centernet_model.CenterNetModel:
  """Creates a centernet model to load TF1 weights."""
  task_config = centernet.CenterNetTask(
      model=centernet.CenterNetModel(
          backbone=backbones.Backbone(
              type="hourglass",
              hourglass=backbones.Hourglass(
                  model_id=model_id, num_hourglasses=num_hourglasses))))
  model_config = task_config.model

  backbone = factory.build_backbone(
      input_specs=tf.keras.layers.InputSpec(shape=[1, 512, 512, 3]),
      backbone_config=model_config.backbone,
      norm_activation_config=model_config.norm_activation)

  task_outputs = task_config.get_output_length_dict()
  head = centernet_head.CenterNetHead(
      input_specs=backbone.output_specs,
      task_outputs=task_outputs,
      input_levels=model_config.head.input_levels)

  detect_generator_obj = detection_generator.CenterNetDetectionGenerator()

  model = centernet_model.CenterNetModel(
      backbone=backbone, head=head, detection_generator=detect_generator_obj)
  logging.info("Successfully created centernet model.")
  return model


def _load_weights(model: centernet_model.CenterNetModel,
                  ckpt_dir_or_file: str,
                  ckpt_backbone_name: str,
                  ckpt_head_name: str):
  """Reads a TF1 checkpoint and loads the weights into the centernet model."""
  weights_dict, _ = read_checkpoints.get_ckpt_weights_as_dict(
      ckpt_dir_or_file)
  load_weights.load_weights_model(
      model=model,
      weights_dict=weights_dict,
      backbone_name=ckpt_backbone_name,
      head_name=ckpt_head_name)


def _save_checkpoint(model: centernet_model.CenterNetModel, ckpt_dir: str):
  """Saves the TF2 centernet model checkpoint."""
  checkpoint = tf.train.Checkpoint(model=model, **model.checkpoint_items)
  manager = tf.train.CheckpointManager(
      checkpoint, directory=ckpt_dir, max_to_keep=3)
  manager.save()
  logging.info("Saved checkpoint to %s.", ckpt_dir)


def convert_checkpoint(model_id: int,
                       num_hourglasses: int,
                       ckpt_dir_or_file: str,
                       ckpt_backbone_name: str,
                       ckpt_head_name: str,
                       output_ckpt_dir: str):
  """Converts a TF1 OD API checkpoint to a TF2 checkpoint."""
  model = _create_centernet_model(
      model_id=model_id, num_hourglasses=num_hourglasses)
  _load_weights(
      model=model,
      ckpt_dir_or_file=ckpt_dir_or_file,
      ckpt_backbone_name=ckpt_backbone_name,
      ckpt_head_name=ckpt_head_name)
  _save_checkpoint(model=model, ckpt_dir=output_ckpt_dir)


def main(_):
  convert_checkpoint(
      model_id=FLAGS.hourglass_id,
      num_hourglasses=FLAGS.num_hourglasses,
      ckpt_dir_or_file=FLAGS.checkpoint_to_convert,
      ckpt_backbone_name=FLAGS.checkpoint_backbone_name,
      ckpt_head_name=FLAGS.checkpoint_head_name,
      output_ckpt_dir=FLAGS.converted_checkpoint_path)


if __name__ == "__main__":
  app.run(main)
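# Example conversion run using the flags defined above; the paths are
# placeholders.
#
#   python3 -m official.vision.beta.projects.centernet.utils.tf2_centernet_checkpoint_converter \
#       --checkpoint_to_convert=/path/to/tf1_odapi_checkpoint \
#       --checkpoint_backbone_name=hourglass104_512 \
#       --checkpoint_head_name=detection_2d \
#       --converted_checkpoint_path=/path/to/output_dir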