Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
e4be7e00
Commit
e4be7e00
authored
Mar 25, 2022
by
Yeqing Li
Committed by
A. Unique TensorFlower
Mar 25, 2022
Browse files
Removes unneeded content of the beta folder.
PiperOrigin-RevId: 437276665
parent
f47405b5
Changes
235
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
0 additions
and
2885 deletions
+0
-2885
official/vision/beta/serving/export_utils.py
official/vision/beta/serving/export_utils.py
+0
-121
official/vision/beta/serving/image_classification.py
official/vision/beta/serving/image_classification.py
+0
-83
official/vision/beta/serving/image_classification_test.py
official/vision/beta/serving/image_classification_test.py
+0
-120
official/vision/beta/serving/semantic_segmentation.py
official/vision/beta/serving/semantic_segmentation.py
+0
-89
official/vision/beta/serving/semantic_segmentation_test.py
official/vision/beta/serving/semantic_segmentation_test.py
+0
-114
official/vision/beta/serving/video_classification.py
official/vision/beta/serving/video_classification.py
+0
-190
official/vision/beta/serving/video_classification_test.py
official/vision/beta/serving/video_classification_test.py
+0
-112
official/vision/beta/tasks/__init__.py
official/vision/beta/tasks/__init__.py
+0
-21
official/vision/beta/tasks/image_classification.py
official/vision/beta/tasks/image_classification.py
+0
-312
official/vision/beta/tasks/maskrcnn.py
official/vision/beta/tasks/maskrcnn.py
+0
-455
official/vision/beta/tasks/retinanet.py
official/vision/beta/tasks/retinanet.py
+0
-358
official/vision/beta/tasks/semantic_segmentation.py
official/vision/beta/tasks/semantic_segmentation.py
+0
-337
official/vision/beta/tasks/video_classification.py
official/vision/beta/tasks/video_classification.py
+0
-353
official/vision/beta/train.py
official/vision/beta/train.py
+0
-69
official/vision/beta/train_spatial_partitioning.py
official/vision/beta/train_spatial_partitioning.py
+0
-151
No files found.
official/vision/beta/serving/export_utils.py
deleted
100644 → 0
View file @
f47405b5
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper utils for export library."""
from
typing
import
List
,
Optional
import
tensorflow
as
tf
# pylint: disable=g-long-lambda
def get_image_input_signatures(input_type: str,
                               batch_size: Optional[int],
                               input_image_size: List[int],
                               num_channels: int = 3):
  """Gets input signatures for an image.

  Args:
    input_type: A `str`, can be either tf_example, image_bytes, or
      image_tensor.
    batch_size: `int` for batch size or None.
    input_image_size: List[int] for the height and width of the input image.
    num_channels: `int` for number of channels in the input image.

  Returns:
    tf.TensorSpec of the input tensor.

  Raises:
    ValueError: If `input_type` is not one of the recognized kinds.
  """
  # Raw uint8 tensors: batch + one unknown dim per spatial axis + channels.
  if input_type == 'image_tensor':
    spatial_dims = [None] * len(input_image_size)
    return tf.TensorSpec(
        shape=[batch_size] + spatial_dims + [num_channels], dtype=tf.uint8)
  # Serialized inputs (encoded bytes or tf.Example protos) are 1-D strings.
  if input_type in ['image_bytes', 'serve_examples', 'tf_example']:
    return tf.TensorSpec(shape=[batch_size], dtype=tf.string)
  # TFLite export always uses batch size 1 with a fully static shape.
  if input_type == 'tflite':
    return tf.TensorSpec(
        shape=[1] + input_image_size + [num_channels], dtype=tf.float32)
  raise ValueError('Unrecognized `input_type`')
def decode_image(encoded_image_bytes: str,
                 input_image_size: List[int],
                 num_channels: int = 3,) -> tf.Tensor:
  """Decodes an image bytes to an image tensor.

  Use `tf.image.decode_image` to decode an image if input is expected to be 2D
  image; otherwise use `tf.io.decode_raw` to convert the raw bytes to tensor
  and reshape it to desire shape.

  Args:
    encoded_image_bytes: An encoded image string to be decoded.
    input_image_size: List[int] for the desired input size. This will be used
      to infer whether the image is 2d or 3d.
    num_channels: `int` for number of image channels.

  Returns:
    A decoded image tensor.
  """
  if len(input_image_size) == 2:
    # Decode an image if 2D input is expected.
    decoded = tf.image.decode_image(
        encoded_image_bytes, channels=num_channels)
  else:
    # Convert raw bytes into a tensor and reshape it, if not 2D input.
    decoded = tf.io.decode_raw(encoded_image_bytes, out_type=tf.uint8)
  # Pin the static rank/channel info so downstream graph code sees a shape.
  decoded.set_shape([None] * len(input_image_size) + [num_channels])
  return decoded
def decode_image_tf_example(tf_example_string_tensor: tf.train.Example,
                            input_image_size: List[int],
                            num_channels: int = 3,
                            encoded_key: str = 'image/encoded') -> tf.Tensor:
  """Decodes a TF Example to an image tensor.

  Args:
    tf_example_string_tensor: A serialized tf.train.Example string.
    input_image_size: List[int] for the desired input size.
    num_channels: `int` for number of image channels.
    encoded_key: Feature key under which the encoded image bytes are stored.

  Returns:
    A decoded image tensor.
  """
  feature_spec = {
      encoded_key: tf.io.FixedLenFeature((), tf.string, default_value=''),
  }
  parsed = tf.io.parse_single_example(
      serialized=tf_example_string_tensor, features=feature_spec)
  # Delegate the actual bytes -> tensor conversion to decode_image.
  return decode_image(
      parsed[encoded_key],
      input_image_size=input_image_size,
      num_channels=num_channels)
def parse_image(inputs, input_type: str, input_image_size: List[int],
                num_channels: int):
  """Parses image.

  Args:
    inputs: Batched input whose interpretation depends on `input_type`:
      serialized tf.Examples, encoded image bytes, or an already-decoded
      image tensor.
    input_type: One of 'tf_example', 'serve_examples', 'image_bytes' or an
      already-decoded tensor type (passed through unchanged).
    input_image_size: List[int] for the spatial dimensions of the input.
    num_channels: `int` for number of image channels.

  Returns:
    A batched uint8 image tensor (or `inputs` unchanged for tensor input).
  """
  per_element_spec = tf.TensorSpec(
      shape=[None] * len(input_image_size) + [num_channels], dtype=tf.uint8)
  if input_type in ['tf_example', 'serve_examples']:
    # Decode each serialized tf.Example in the batch independently.
    example_decoder = (
        lambda x: decode_image_tf_example(x, input_image_size, num_channels))
    return tf.map_fn(
        example_decoder, elems=inputs, fn_output_signature=per_element_spec)
  if input_type == 'image_bytes':
    # Decode each encoded image string in the batch independently.
    bytes_decoder = lambda x: decode_image(x, input_image_size, num_channels)
    return tf.map_fn(
        bytes_decoder, elems=inputs, fn_output_signature=per_element_spec)
  # Tensor input needs no decoding.
  return inputs
official/vision/beta/serving/image_classification.py
deleted
100644 → 0
View file @
f47405b5
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image classification input and model functions for serving/inference."""
import
tensorflow
as
tf
from
official.vision.beta.modeling
import
factory
from
official.vision.beta.ops
import
preprocess_ops
from
official.vision.beta.serving
import
export_base
# ImageNet channel-wise statistics, scaled from the [0, 1] range to [0, 255].
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
class ClassificationModule(export_base.ExportModule):
  """Export module for image classification serving/inference."""

  def _build_model(self):
    """Builds the classification model with a fixed-batch input spec."""
    input_specs = tf.keras.layers.InputSpec(
        shape=[self._batch_size] + self._input_image_size + [3])
    return factory.build_classification_model(
        input_specs=input_specs,
        model_config=self.params.task.model,
        l2_regularizer=None)

  def _build_inputs(self, image):
    """Builds classification model inputs for serving.

    Center-crops, resizes, and normalizes a single image.
    """
    image = preprocess_ops.center_crop_image(image)
    image = tf.image.resize(
        image, self._input_image_size, method=tf.image.ResizeMethod.BILINEAR)
    # Reshape to a fully static shape so the exported graph has known dims.
    image = tf.reshape(
        image, [self._input_image_size[0], self._input_image_size[1], 3])
    # Normalizes image with mean and std pixel values.
    image = preprocess_ops.normalize_image(
        image, offset=MEAN_RGB, scale=STDDEV_RGB)
    return image

  def serve(self, images):
    """Cast image to float and run inference.

    Args:
      images: uint8 Tensor of shape [batch_size, None, None, 3]

    Returns:
      Tensor holding classification output logits.
    """
    # Skip image preprocessing when input_type is tflite so it is compatible
    # with TFLite quantization.
    if self._input_type != 'tflite':
      with tf.device('cpu:0'):
        images = tf.cast(images, dtype=tf.float32)
        images = tf.nest.map_structure(
            tf.identity,
            tf.map_fn(
                self._build_inputs,
                elems=images,
                fn_output_signature=tf.TensorSpec(
                    shape=self._input_image_size + [3], dtype=tf.float32),
                parallel_iterations=32))
    logits = self.inference_step(images)
    probs = tf.nn.softmax(logits)
    return {'logits': logits, 'probs': probs}
official/vision/beta/serving/image_classification_test.py
deleted
100644 → 0
View file @
f47405b5
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test for image classification export lib."""
import
io
import
os
from
absl.testing
import
parameterized
import
numpy
as
np
from
PIL
import
Image
import
tensorflow
as
tf
from
official.core
import
exp_factory
from
official.vision.beta
import
configs
# pylint: disable=unused-import
from
official.vision.beta.serving
import
image_classification
class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase):
  """Round-trip export/import tests for ClassificationModule."""

  def _get_classification_module(self, input_type):
    # Small ResNet-18 backbone keeps the test light.
    params = exp_factory.get_exp_config('resnet_imagenet')
    params.task.model.backbone.resnet.model_id = 18
    return image_classification.ClassificationModule(
        params,
        batch_size=1,
        input_image_size=[224, 224],
        input_type=input_type)

  def _export_from_module(self, module, input_type, save_directory):
    signatures = module.get_inference_signatures(
        {input_type: 'serving_default'})
    tf.saved_model.save(module, save_directory, signatures=signatures)

  def _get_dummy_input(self, input_type):
    """Get dummy input for the given input type."""
    if input_type == 'image_tensor':
      return tf.zeros((1, 224, 224, 3), dtype=np.uint8)
    elif input_type == 'image_bytes':
      image = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
      byte_io = io.BytesIO()
      image.save(byte_io, 'PNG')
      return [byte_io.getvalue()]
    elif input_type == 'tf_example':
      image_tensor = tf.zeros((224, 224, 3), dtype=tf.uint8)
      encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).numpy()
      example = tf.train.Example(
          features=tf.train.Features(
              feature={
                  'image/encoded':
                      tf.train.Feature(
                          bytes_list=tf.train.BytesList(value=[encoded_jpeg])),
              })).SerializeToString()
      return [example]
    elif input_type == 'tflite':
      return tf.zeros((1, 224, 224, 3), dtype=np.float32)

  @parameterized.parameters(
      {'input_type': 'image_tensor'},
      {'input_type': 'image_bytes'},
      {'input_type': 'tf_example'},
      {'input_type': 'tflite'},
  )
  def test_export(self, input_type='image_tensor'):
    tmp_dir = self.get_temp_dir()
    module = self._get_classification_module(input_type)
    # Test that the model restores any attrs that are trackable objects
    # (eg: tables, resource variables, keras models/layers, tf.hub modules).
    module.model.test_trackable = tf.keras.layers.InputLayer(input_shape=(4,))
    self._export_from_module(module, input_type, tmp_dir)

    self.assertTrue(os.path.exists(os.path.join(tmp_dir, 'saved_model.pb')))
    self.assertTrue(
        os.path.exists(os.path.join(tmp_dir, 'variables', 'variables.index')))
    self.assertTrue(
        os.path.exists(
            os.path.join(tmp_dir, 'variables',
                         'variables.data-00000-of-00001')))

    imported = tf.saved_model.load(tmp_dir)
    classification_fn = imported.signatures['serving_default']

    images = self._get_dummy_input(input_type)
    if input_type != 'tflite':
      # Replicate the module's preprocessing to compute the expected result.
      processed_images = tf.nest.map_structure(
          tf.stop_gradient,
          tf.map_fn(
              module._build_inputs,
              elems=tf.zeros((1, 224, 224, 3), dtype=tf.uint8),
              fn_output_signature=tf.TensorSpec(
                  shape=[224, 224, 3], dtype=tf.float32)))
    else:
      processed_images = images
    expected_logits = module.model(processed_images, training=False)
    expected_prob = tf.nn.softmax(expected_logits)
    out = classification_fn(tf.constant(images))

    # The imported model should contain any trackable attrs that the original
    # model had.
    self.assertTrue(hasattr(imported.model, 'test_trackable'))
    self.assertAllClose(out['logits'].numpy(), expected_logits.numpy())
    self.assertAllClose(out['probs'].numpy(), expected_prob.numpy())
if __name__ == '__main__':
  tf.test.main()
official/vision/beta/serving/semantic_segmentation.py
deleted
100644 → 0
View file @
f47405b5
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Semantic segmentation input and model functions for serving/inference."""
import
tensorflow
as
tf
from
official.vision.beta.modeling
import
factory
from
official.vision.beta.ops
import
preprocess_ops
from
official.vision.beta.serving
import
export_base
# ImageNet channel-wise statistics, scaled from the [0, 1] range to [0, 255].
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
class SegmentationModule(export_base.ExportModule):
  """Export module for semantic segmentation serving/inference."""

  def _build_model(self):
    """Builds the segmentation model with a fixed-batch input spec."""
    input_specs = tf.keras.layers.InputSpec(
        shape=[self._batch_size] + self._input_image_size + [3])
    return factory.build_segmentation_model(
        input_specs=input_specs,
        model_config=self.params.task.model,
        l2_regularizer=None)

  def _build_inputs(self, image):
    """Builds segmentation model inputs for serving.

    Normalizes, then resizes-and-crops a single image; also returns the
    resize bookkeeping (`image_info`) needed to undo the transform.
    """
    # Normalizes image with mean and std pixel values.
    image = preprocess_ops.normalize_image(
        image, offset=MEAN_RGB, scale=STDDEV_RGB)
    image, image_info = preprocess_ops.resize_and_crop_image(
        image,
        self._input_image_size,
        padded_size=self._input_image_size,
        aug_scale_min=1.0,
        aug_scale_max=1.0)
    return image, image_info

  def serve(self, images):
    """Cast image to float and run inference.

    Args:
      images: uint8 Tensor of shape [batch_size, None, None, 3]

    Returns:
      Tensor holding classification output logits.
    """
    # Skip image preprocessing when input_type is tflite so it is compatible
    # with TFLite quantization.
    image_info = None
    if self._input_type != 'tflite':
      with tf.device('cpu:0'):
        images = tf.cast(images, dtype=tf.float32)
        images_spec = tf.TensorSpec(
            shape=self._input_image_size + [3], dtype=tf.float32)
        image_info_spec = tf.TensorSpec(shape=[4, 2], dtype=tf.float32)
        images, image_info = tf.nest.map_structure(
            tf.identity,
            tf.map_fn(
                self._build_inputs,
                elems=images,
                fn_output_signature=(images_spec, image_info_spec),
                parallel_iterations=32))
    outputs = self.inference_step(images)
    # Upsample logits back to the serving input resolution.
    outputs['logits'] = tf.image.resize(
        outputs['logits'], self._input_image_size, method='bilinear')
    if image_info is not None:
      outputs.update({'image_info': image_info})
    return outputs
official/vision/beta/serving/semantic_segmentation_test.py
deleted
100644 → 0
View file @
f47405b5
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test for semantic segmentation export lib."""
import
io
import
os
from
absl.testing
import
parameterized
import
numpy
as
np
from
PIL
import
Image
import
tensorflow
as
tf
from
official.core
import
exp_factory
from
official.vision.beta
import
configs
# pylint: disable=unused-import
from
official.vision.beta.serving
import
semantic_segmentation
class SemanticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase):
  """Round-trip export/import tests for SegmentationModule."""

  def _get_segmentation_module(self, input_type):
    params = exp_factory.get_exp_config('mnv2_deeplabv3_pascal')
    return semantic_segmentation.SegmentationModule(
        params,
        batch_size=1,
        input_image_size=[112, 112],
        input_type=input_type)

  def _export_from_module(self, module, input_type, save_directory):
    signatures = module.get_inference_signatures(
        {input_type: 'serving_default'})
    tf.saved_model.save(module, save_directory, signatures=signatures)

  def _get_dummy_input(self, input_type):
    """Get dummy input for the given input type."""
    if input_type == 'image_tensor':
      return tf.zeros((1, 112, 112, 3), dtype=np.uint8)
    elif input_type == 'image_bytes':
      image = Image.fromarray(np.zeros((112, 112, 3), dtype=np.uint8))
      byte_io = io.BytesIO()
      image.save(byte_io, 'PNG')
      return [byte_io.getvalue()]
    elif input_type == 'tf_example':
      image_tensor = tf.zeros((112, 112, 3), dtype=tf.uint8)
      encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).numpy()
      example = tf.train.Example(
          features=tf.train.Features(
              feature={
                  'image/encoded':
                      tf.train.Feature(
                          bytes_list=tf.train.BytesList(value=[encoded_jpeg])),
              })).SerializeToString()
      return [example]
    elif input_type == 'tflite':
      return tf.zeros((1, 112, 112, 3), dtype=np.float32)

  @parameterized.parameters(
      {'input_type': 'image_tensor'},
      {'input_type': 'image_bytes'},
      {'input_type': 'tf_example'},
      {'input_type': 'tflite'},
  )
  def test_export(self, input_type='image_tensor'):
    tmp_dir = self.get_temp_dir()
    module = self._get_segmentation_module(input_type)
    self._export_from_module(module, input_type, tmp_dir)

    self.assertTrue(os.path.exists(os.path.join(tmp_dir, 'saved_model.pb')))
    self.assertTrue(
        os.path.exists(os.path.join(tmp_dir, 'variables', 'variables.index')))
    self.assertTrue(
        os.path.exists(
            os.path.join(tmp_dir, 'variables',
                         'variables.data-00000-of-00001')))

    imported = tf.saved_model.load(tmp_dir)
    segmentation_fn = imported.signatures['serving_default']

    images = self._get_dummy_input(input_type)
    if input_type != 'tflite':
      # Replicate the module's preprocessing to compute the expected result;
      # _build_inputs returns (image, image_info) so drop the info here.
      processed_images, _ = tf.nest.map_structure(
          tf.stop_gradient,
          tf.map_fn(
              module._build_inputs,
              elems=tf.zeros((1, 112, 112, 3), dtype=tf.uint8),
              fn_output_signature=(
                  tf.TensorSpec(shape=[112, 112, 3], dtype=tf.float32),
                  tf.TensorSpec(shape=[4, 2], dtype=tf.float32))))
    else:
      processed_images = images
    expected_output = tf.image.resize(
        module.model(processed_images, training=False)['logits'], [112, 112],
        method='bilinear')
    out = segmentation_fn(tf.constant(images))
    self.assertAllClose(out['logits'].numpy(), expected_output.numpy())
if __name__ == '__main__':
  tf.test.main()
official/vision/beta/serving/video_classification.py
deleted
100644 → 0
View file @
f47405b5
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Video classification input and model functions for serving/inference."""
from
typing
import
Mapping
,
Dict
,
Text
import
tensorflow
as
tf
from
official.vision.beta.dataloaders
import
video_input
from
official.vision.beta.serving
import
export_base
from
official.vision.beta.tasks
import
video_classification
# ImageNet channel-wise statistics, scaled from the [0, 1] range to [0, 255].
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
class VideoClassificationModule(export_base.ExportModule):
  """Export module for video classification serving/inference."""

  def _build_model(self):
    """Builds the video model and caches preprocessing params from config."""
    input_params = self.params.task.train_data
    self._num_frames = input_params.feature_shape[0]
    self._stride = input_params.temporal_stride
    self._min_resize = input_params.min_image_size
    self._crop_size = input_params.feature_shape[1]
    self._output_audio = input_params.output_audio
    task = video_classification.VideoClassificationTask(self.params.task)
    return task.build_model()

  def _decode_tf_example(self, encoded_inputs: tf.Tensor):
    """Parses one serialized SequenceExample into dense feature tensors."""
    sequence_description = {
        # Each image is a string encoding JPEG.
        video_input.IMAGE_KEY: tf.io.FixedLenSequenceFeature((), tf.string),
    }
    if self._output_audio:
      sequence_description[self._params.task.validation_data.audio_feature] = (
          tf.io.VarLenFeature(dtype=tf.float32))
    _, decoded_tensors = tf.io.parse_single_sequence_example(
        encoded_inputs, {}, sequence_description)
    # Densify any sparse features (e.g. variable-length audio).
    for key, value in decoded_tensors.items():
      if isinstance(value, tf.SparseTensor):
        decoded_tensors[key] = tf.sparse.to_dense(value)
    return decoded_tensors

  def _preprocess_image(self, image):
    """Applies eval-mode video frame sampling/cropping for one example."""
    image = video_input.process_image(
        image=image,
        is_training=False,
        num_frames=self._num_frames,
        stride=self._stride,
        num_test_clips=1,
        min_resize=self._min_resize,
        crop_size=self._crop_size,
        num_crops=1)
    image = tf.cast(image, tf.float32)
    # Use config.
    features = {'image': image}
    return features

  def _preprocess_audio(self, audio):
    """Samples and shapes the audio feature for one example."""
    features = {}
    audio = tf.cast(audio, dtype=tf.float32)
    # Use config.
    audio = video_input.preprocess_ops_3d.sample_sequence(
        audio, 20, random=False, stride=1)
    audio = tf.ensure_shape(
        audio, self._params.task.validation_data.audio_feature_shape)
    features['audio'] = audio
    return features

  @tf.function
  def inference_from_tf_example(
      self, encoded_inputs: tf.Tensor) -> Mapping[str, tf.Tensor]:
    with tf.device('cpu:0'):
      if self._output_audio:
        inputs = tf.map_fn(
            self._decode_tf_example, (encoded_inputs),
            fn_output_signature={
                video_input.IMAGE_KEY:
                    tf.string,
                self._params.task.validation_data.audio_feature:
                    tf.float32
            })
        return self.serve(inputs['image'], inputs['audio'])
      else:
        inputs = tf.map_fn(
            self._decode_tf_example, (encoded_inputs),
            fn_output_signature={
                video_input.IMAGE_KEY: tf.string,
            })
        # No audio configured: feed a dummy tensor; serve() ignores it.
        return self.serve(inputs[video_input.IMAGE_KEY], tf.zeros([1, 1]))

  @tf.function
  def inference_from_image_tensors(
      self, input_frames: tf.Tensor) -> Mapping[str, tf.Tensor]:
    return self.serve(input_frames, tf.zeros([1, 1]))

  @tf.function
  def inference_from_image_audio_tensors(
      self, input_frames: tf.Tensor,
      input_audio: tf.Tensor) -> Mapping[str, tf.Tensor]:
    return self.serve(input_frames, input_audio)

  @tf.function
  def inference_from_image_bytes(self, inputs: tf.Tensor):
    raise NotImplementedError(
        'Video classification do not support image bytes input.')

  def serve(self, input_frames: tf.Tensor, input_audio: tf.Tensor):
    """Cast image to float and run inference.

    Args:
      input_frames: uint8 Tensor of shape [batch_size, None, None, 3]
      input_audio: float32

    Returns:
      Tensor holding classification output logits.
    """
    with tf.device('cpu:0'):
      inputs = tf.map_fn(
          self._preprocess_image, (input_frames),
          fn_output_signature={
              'image': tf.float32,
          })
      if self._output_audio:
        inputs.update(
            tf.map_fn(
                self._preprocess_audio, (input_audio),
                fn_output_signature={'audio': tf.float32}))
    logits = self.inference_step(inputs)
    # Multilabel heads use independent sigmoids; otherwise softmax.
    if self.params.task.train_data.is_multilabel:
      probs = tf.math.sigmoid(logits)
    else:
      probs = tf.nn.softmax(logits)
    return {'logits': logits, 'probs': probs}

  def get_inference_signatures(self, function_keys: Dict[Text, Text]):
    """Gets defined function signatures.

    Args:
      function_keys: A dictionary with keys as the function to create signature
        for and values as the signature keys when returns.

    Returns:
      A dictionary with key as signature key and value as concrete functions
      that can be used for tf.saved_model.save.

    Raises:
      ValueError: If a key is not one of the supported input types.
    """
    signatures = {}
    for key, def_name in function_keys.items():
      if key == 'image_tensor':
        input_signature = tf.TensorSpec(
            shape=[self._batch_size] + self._input_image_size + [3],
            dtype=tf.uint8,
            name='INPUT_FRAMES')
        signatures[
            def_name] = self.inference_from_image_tensors.get_concrete_function(
                input_signature)
      elif key == 'frames_audio':
        input_signature = [
            tf.TensorSpec(
                shape=[self._batch_size] + self._input_image_size + [3],
                dtype=tf.uint8,
                name='INPUT_FRAMES'),
            tf.TensorSpec(
                shape=[self._batch_size] +
                self.params.task.train_data.audio_feature_shape,
                dtype=tf.float32,
                name='INPUT_AUDIO')
        ]
        signatures[
            def_name] = self.inference_from_image_audio_tensors.get_concrete_function(
                input_signature)
      elif key == 'serve_examples' or key == 'tf_example':
        input_signature = tf.TensorSpec(
            shape=[self._batch_size], dtype=tf.string)
        signatures[
            def_name] = self.inference_from_tf_example.get_concrete_function(
                input_signature)
      else:
        raise ValueError('Unrecognized `input_type`')
    return signatures
official/vision/beta/serving/video_classification_test.py
deleted
100644 → 0
View file @
f47405b5
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# import io
import
os
import
random
from
absl.testing
import
parameterized
import
numpy
as
np
import
tensorflow
as
tf
from
official.core
import
exp_factory
from
official.vision.beta
import
configs
# pylint: disable=unused-import
from
official.vision.beta.dataloaders
import
tfexample_utils
from
official.vision.beta.serving
import
video_classification
class VideoClassificationTest(tf.test.TestCase, parameterized.TestCase):
  """Round-trip export/import tests for VideoClassificationModule."""

  def _get_classification_module(self):
    # Shrink feature shapes so the test model stays small and fast.
    params = exp_factory.get_exp_config('video_classification_ucf101')
    params.task.train_data.feature_shape = (8, 64, 64, 3)
    params.task.validation_data.feature_shape = (8, 64, 64, 3)
    params.task.model.backbone.resnet_3d.model_id = 50
    return video_classification.VideoClassificationModule(
        params, batch_size=1, input_image_size=[8, 64, 64])

  def _export_from_module(self, module, input_type, save_directory):
    signatures = module.get_inference_signatures(
        {input_type: 'serving_default'})
    tf.saved_model.save(module, save_directory, signatures=signatures)

  def _get_dummy_input(self, input_type, module=None):
    """Get dummy input for the given input type."""
    if input_type == 'image_tensor':
      images = np.random.randint(
          low=0, high=255, size=(1, 8, 64, 64, 3), dtype=np.uint8)
      return images, images
    elif input_type == 'tf_example':
      example = tfexample_utils.make_video_test_example(
          image_shape=(64, 64, 3),
          audio_shape=(20, 128),
          label=random.randint(0, 100)).SerializeToString()
      images = tf.nest.map_structure(
          tf.stop_gradient,
          tf.map_fn(
              module._decode_tf_example,
              elems=tf.constant([example]),
              fn_output_signature={
                  video_classification.video_input.IMAGE_KEY: tf.string,
              }))
      images = images[video_classification.video_input.IMAGE_KEY]
      return [example], images
    else:
      raise ValueError(f'{input_type}')

  @parameterized.parameters(
      {'input_type': 'image_tensor'},
      {'input_type': 'tf_example'},
  )
  def test_export(self, input_type):
    tmp_dir = self.get_temp_dir()
    module = self._get_classification_module()
    self._export_from_module(module, input_type, tmp_dir)

    self.assertTrue(os.path.exists(os.path.join(tmp_dir, 'saved_model.pb')))
    self.assertTrue(
        os.path.exists(os.path.join(tmp_dir, 'variables', 'variables.index')))
    self.assertTrue(
        os.path.exists(
            os.path.join(tmp_dir, 'variables',
                         'variables.data-00000-of-00001')))

    imported = tf.saved_model.load(tmp_dir)
    classification_fn = imported.signatures['serving_default']

    images, images_tensor = self._get_dummy_input(input_type, module)
    # Replicate the module's preprocessing to compute the expected result.
    processed_images = tf.nest.map_structure(
        tf.stop_gradient,
        tf.map_fn(
            module._preprocess_image,
            elems=images_tensor,
            fn_output_signature={
                'image': tf.float32,
            }))
    expected_logits = module.model(processed_images, training=False)
    expected_prob = tf.nn.softmax(expected_logits)
    out = classification_fn(tf.constant(images))

    self.assertAllClose(out['logits'].numpy(), expected_logits.numpy())
    self.assertAllClose(out['probs'].numpy(), expected_prob.numpy())
if __name__ == '__main__':
  tf.test.main()
official/vision/beta/tasks/__init__.py
deleted
100644 → 0
View file @
f47405b5
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tasks package definition."""
from
official.vision.beta.tasks
import
image_classification
from
official.vision.beta.tasks
import
maskrcnn
from
official.vision.beta.tasks
import
retinanet
from
official.vision.beta.tasks
import
semantic_segmentation
from
official.vision.beta.tasks
import
video_classification
official/vision/beta/tasks/image_classification.py
deleted
100644 → 0
View file @
f47405b5
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image classification task definition."""
from
typing
import
Any
,
Optional
,
List
,
Tuple
from
absl
import
logging
import
tensorflow
as
tf
from
official.common
import
dataset_fn
from
official.core
import
base_task
from
official.core
import
task_factory
from
official.modeling
import
tf_utils
from
official.vision.beta.configs
import
image_classification
as
exp_cfg
from
official.vision.beta.dataloaders
import
classification_input
from
official.vision.beta.dataloaders
import
input_reader_factory
from
official.vision.beta.dataloaders
import
tfds_factory
from
official.vision.beta.modeling
import
factory
from
official.vision.beta.ops
import
augment
@task_factory.register_task_cls(exp_cfg.ImageClassificationTask)
class ImageClassificationTask(base_task.Task):
  """A task for image classification.

  Provides the artifacts the trainer needs for a classification experiment:
  model construction, checkpoint initialization, input pipelines, losses,
  metrics, and the per-replica train/validation steps.
  """

  def build_model(self):
    """Builds classification model."""
    input_specs = tf.keras.layers.InputSpec(
        shape=[None] + self.task_config.model.input_size)

    l2_weight_decay = self.task_config.losses.l2_weight_decay
    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    l2_regularizer = (tf.keras.regularizers.l2(
        l2_weight_decay / 2.0) if l2_weight_decay else None)

    model = factory.build_classification_model(
        input_specs=input_specs,
        model_config=self.task_config.model,
        l2_regularizer=l2_regularizer)
    return model

  def initialize(self, model: tf.keras.Model):
    """Loads pretrained checkpoint.

    Restores either the full model or only the backbone, depending on
    `task_config.init_checkpoint_modules`. No-op when no checkpoint is
    configured.

    Raises:
      ValueError: If `init_checkpoint_modules` is neither 'all' nor 'backbone'.
    """
    if not self.task_config.init_checkpoint:
      return
    ckpt_dir_or_file = self.task_config.init_checkpoint
    # A directory means "use the latest checkpoint inside it".
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

    # Restoring checkpoint.
    if self.task_config.init_checkpoint_modules == 'all':
      ckpt = tf.train.Checkpoint(model=model)
      status = ckpt.read(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()
    elif self.task_config.init_checkpoint_modules == 'backbone':
      ckpt = tf.train.Checkpoint(backbone=model.backbone)
      status = ckpt.read(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()
    else:
      raise ValueError(
          "Only 'all' or 'backbone' can be used to initialize the model.")

    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)

  def build_inputs(
      self,
      params: exp_cfg.DataConfig,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    """Builds classification input.

    Args:
      params: The data config describing the split to read.
      input_context: Optional distribution-strategy input context for sharding.

    Returns:
      A `tf.data.Dataset` of (features, labels).
    """
    num_classes = self.task_config.model.num_classes
    input_size = self.task_config.model.input_size
    # NOTE(review): field keys and is_multilabel are always read from
    # train_data, even when building the validation input — confirm intended.
    image_field_key = self.task_config.train_data.image_field_key
    label_field_key = self.task_config.train_data.label_field_key
    is_multilabel = self.task_config.train_data.is_multilabel

    if params.tfds_name:
      decoder = tfds_factory.get_classification_decoder(params.tfds_name)
    else:
      decoder = classification_input.Decoder(
          image_field_key=image_field_key,
          label_field_key=label_field_key,
          is_multilabel=is_multilabel)

    parser = classification_input.Parser(
        output_size=input_size[:2],
        num_classes=num_classes,
        image_field_key=image_field_key,
        label_field_key=label_field_key,
        decode_jpeg_only=params.decode_jpeg_only,
        aug_rand_hflip=params.aug_rand_hflip,
        aug_type=params.aug_type,
        color_jitter=params.color_jitter,
        random_erasing=params.random_erasing,
        is_multilabel=is_multilabel,
        dtype=params.dtype)

    # Mixup/CutMix runs as a batch-level postprocess, after parsing.
    postprocess_fn = None
    if params.mixup_and_cutmix:
      postprocess_fn = augment.MixupAndCutmix(
          mixup_alpha=params.mixup_and_cutmix.mixup_alpha,
          cutmix_alpha=params.mixup_and_cutmix.cutmix_alpha,
          prob=params.mixup_and_cutmix.prob,
          label_smoothing=params.mixup_and_cutmix.label_smoothing,
          num_classes=num_classes)

    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training),
        postprocess_fn=postprocess_fn)

    dataset = reader.read(input_context=input_context)

    return dataset

  def build_losses(self,
                   labels: tf.Tensor,
                   model_outputs: tf.Tensor,
                   aux_losses: Optional[Any] = None) -> tf.Tensor:
    """Builds sparse categorical cross entropy loss.

    Args:
      labels: Input groundtruth labels.
      model_outputs: Output logits of the classifier.
      aux_losses: The auxiliary loss tensors, i.e. `losses` in tf.keras.Model.

    Returns:
      The total loss tensor.
    """
    losses_config = self.task_config.losses
    is_multilabel = self.task_config.train_data.is_multilabel

    if not is_multilabel:
      # Single-label: choose the cross-entropy variant matching the label
      # encoding (one-hot, soft, or sparse integer labels).
      if losses_config.one_hot:
        total_loss = tf.keras.losses.categorical_crossentropy(
            labels,
            model_outputs,
            from_logits=True,
            label_smoothing=losses_config.label_smoothing)
      elif losses_config.soft_labels:
        total_loss = tf.nn.softmax_cross_entropy_with_logits(
            labels, model_outputs)
      else:
        total_loss = tf.keras.losses.sparse_categorical_crossentropy(
            labels, model_outputs, from_logits=True)
    else:
      # Multi-label weighted binary cross entropy loss.
      total_loss = tf.nn.sigmoid_cross_entropy_with_logits(
          labels=labels, logits=model_outputs)
      total_loss = tf.reduce_sum(total_loss, axis=-1)

    total_loss = tf_utils.safe_mean(total_loss)
    if aux_losses:
      total_loss += tf.add_n(aux_losses)

    total_loss = losses_config.loss_weight * total_loss
    return total_loss

  def build_metrics(self,
                    training: bool = True) -> List[tf.keras.metrics.Metric]:
    """Gets streaming metrics for training/validation."""
    is_multilabel = self.task_config.train_data.is_multilabel
    if not is_multilabel:
      k = self.task_config.evaluation.top_k
      # Metric flavor must match the label encoding used by build_losses.
      if (self.task_config.losses.one_hot or
          self.task_config.losses.soft_labels):
        metrics = [
            tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
            tf.keras.metrics.TopKCategoricalAccuracy(
                k=k, name='top_{}_accuracy'.format(k))]
      else:
        metrics = [
            tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
            tf.keras.metrics.SparseTopKCategoricalAccuracy(
                k=k, name='top_{}_accuracy'.format(k))]
    else:
      metrics = []
      # These metrics destabilize the training if included in training. The
      # jobs fail due to OOM.
      # TODO(arashwan): Investigate adding following metric to train.
      if not training:
        metrics = [
            tf.keras.metrics.AUC(
                name='globalPR-AUC',
                curve='PR',
                multi_label=False,
                from_logits=True),
            tf.keras.metrics.AUC(
                name='meanPR-AUC',
                curve='PR',
                multi_label=True,
                num_labels=self.task_config.model.num_classes,
                from_logits=True),
        ]
    return metrics

  def train_step(self,
                 inputs: Tuple[Any, Any],
                 model: tf.keras.Model,
                 optimizer: tf.keras.optimizers.Optimizer,
                 metrics: Optional[List[Any]] = None):
    """Does forward and backward.

    Args:
      inputs: A tuple of input tensors of (features, labels).
      model: A tf.keras.Model instance.
      optimizer: The optimizer for this training step.
      metrics: A nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs
    is_multilabel = self.task_config.train_data.is_multilabel
    if self.task_config.losses.one_hot and not is_multilabel:
      labels = tf.one_hot(labels, self.task_config.model.num_classes)

    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    with tf.GradientTape() as tape:
      outputs = model(features, training=True)

      # Casting output layer as float32 is necessary when mixed_precision is
      # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
      outputs = tf.nest.map_structure(
          lambda x: tf.cast(x, tf.float32), outputs)

      # Computes per-replica loss.
      loss = self.build_losses(
          model_outputs=outputs, labels=labels, aux_losses=model.losses)
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      scaled_loss = loss / num_replicas

      # For mixed_precision policy, when LossScaleOptimizer is used, loss is
      # scaled for numerical stability.
      if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    # Scales back gradient before apply_gradients when LossScaleOptimizer is
    # used.
    if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    logs = {self.loss: loss}
    if metrics:
      self.process_metrics(metrics, labels, outputs)
    elif model.compiled_metrics:
      self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
      logs.update({m.name: m.result() for m in model.metrics})
    return logs

  def validation_step(self,
                      inputs: Tuple[Any, Any],
                      model: tf.keras.Model,
                      metrics: Optional[List[Any]] = None):
    """Runs validation step.

    Args:
      inputs: A tuple of input tensors of (features, labels).
      model: A tf.keras.Model instance.
      metrics: A nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs
    one_hot = self.task_config.losses.one_hot
    soft_labels = self.task_config.losses.soft_labels
    is_multilabel = self.task_config.train_data.is_multilabel
    # One-hot encode labels to match the (one-hot / soft-label) loss inputs.
    if (one_hot or soft_labels) and not is_multilabel:
      labels = tf.one_hot(labels, self.task_config.model.num_classes)

    outputs = self.inference_step(features, model)
    # Cast to float32 so loss/metrics are computed in full precision even
    # under mixed-precision policies.
    outputs = tf.nest.map_structure(
        lambda x: tf.cast(x, tf.float32), outputs)
    loss = self.build_losses(
        model_outputs=outputs, labels=labels, aux_losses=model.losses)

    logs = {self.loss: loss}
    if metrics:
      self.process_metrics(metrics, labels, outputs)
    elif model.compiled_metrics:
      self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
      logs.update({m.name: m.result() for m in model.metrics})
    return logs

  def inference_step(self, inputs: tf.Tensor, model: tf.keras.Model):
    """Performs the forward step."""
    return model(inputs, training=False)
official/vision/beta/tasks/maskrcnn.py
deleted
100644 → 0
View file @
f47405b5
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MaskRCNN task definition."""
import
os
from
typing
import
Any
,
Optional
,
List
,
Tuple
,
Mapping
from
absl
import
logging
import
tensorflow
as
tf
from
official.common
import
dataset_fn
from
official.core
import
base_task
from
official.core
import
task_factory
from
official.vision.beta.configs
import
maskrcnn
as
exp_cfg
from
official.vision.beta.dataloaders
import
input_reader_factory
from
official.vision.beta.dataloaders
import
maskrcnn_input
from
official.vision.beta.dataloaders
import
tf_example_decoder
from
official.vision.beta.dataloaders
import
tf_example_label_map_decoder
from
official.vision.beta.evaluation
import
coco_evaluator
from
official.vision.beta.evaluation
import
coco_utils
from
official.vision.beta.losses
import
maskrcnn_losses
from
official.vision.beta.modeling
import
factory
def zero_out_disallowed_class_ids(batch_class_ids: tf.Tensor,
                                  allowed_class_ids: List[int]):
  """Zero out IDs of classes not in allowed_class_ids.

  Args:
    batch_class_ids: A [batch_size, num_instances] int tensor of input
      class IDs.
    allowed_class_ids: A python list of class IDs which we want to allow.

  Returns:
    filtered_class_ids: A [batch_size, num_instances] int tensor with any
      class ID not in allowed_class_ids set to 0.
  """
  allowed = tf.constant(allowed_class_ids, dtype=batch_class_ids.dtype)
  # Broadcast-compare every instance ID against every allowed ID, producing a
  # [batch_size, num_instances, num_allowed] boolean tensor, then collapse the
  # last axis: an instance is kept iff it matches at least one allowed ID.
  pairwise_match = tf.equal(batch_class_ids[:, :, tf.newaxis],
                            allowed[tf.newaxis, tf.newaxis, :])
  is_allowed = tf.reduce_any(pairwise_match, axis=2)
  return tf.where(is_allowed, batch_class_ids,
                  tf.zeros_like(batch_class_ids))
@task_factory.register_task_cls(exp_cfg.MaskRCNNTask)
class MaskRCNNTask(base_task.Task):
  """A single-replica view of training procedure.

  Mask R-CNN task provides artifacts for training/evaluation procedures,
  including loading/iterating over Datasets, initializing the model,
  calculating the loss, post-processing, and customized metrics with
  reduction.
  """

  # NOTE(review): this class reads the config through both `self.task_config`
  # and `self._task_config` — presumably aliases exposed by base_task.Task;
  # confirm and prefer one consistently.

  def build_model(self):
    """Build Mask R-CNN model."""
    input_specs = tf.keras.layers.InputSpec(
        shape=[None] + self.task_config.model.input_size)

    l2_weight_decay = self.task_config.losses.l2_weight_decay
    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    l2_regularizer = (tf.keras.regularizers.l2(
        l2_weight_decay / 2.0) if l2_weight_decay else None)

    model = factory.build_maskrcnn(
        input_specs=input_specs,
        model_config=self.task_config.model,
        l2_regularizer=l2_regularizer)
    return model

  def initialize(self, model: tf.keras.Model):
    """Loading pretrained checkpoint.

    Restores either all checkpoint items, or any combination of the backbone
    and decoder, depending on `task_config.init_checkpoint_modules`. No-op
    when no checkpoint is configured.
    """
    if not self.task_config.init_checkpoint:
      return

    ckpt_dir_or_file = self.task_config.init_checkpoint
    # A directory means "use the latest checkpoint inside it".
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

    # Restoring checkpoint.
    if self.task_config.init_checkpoint_modules == 'all':
      ckpt = tf.train.Checkpoint(**model.checkpoint_items)
      status = ckpt.read(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()
    else:
      # Partial restore: init_checkpoint_modules is treated as a collection
      # that may contain 'backbone' and/or 'decoder'.
      ckpt_items = {}
      if 'backbone' in self.task_config.init_checkpoint_modules:
        ckpt_items.update(backbone=model.backbone)
      if 'decoder' in self.task_config.init_checkpoint_modules:
        ckpt_items.update(decoder=model.decoder)

      ckpt = tf.train.Checkpoint(**ckpt_items)
      status = ckpt.read(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()

    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)

  def build_inputs(self,
                   params: exp_cfg.DataConfig,
                   input_context: Optional[tf.distribute.InputContext] = None):
    """Build input dataset.

    Args:
      params: The data config describing the split to read.
      input_context: Optional distribution-strategy input context for sharding.

    Returns:
      A `tf.data.Dataset` of (images, labels).

    Raises:
      ValueError: If `params.decoder.type` is not a supported decoder type.
    """
    decoder_cfg = params.decoder.get()
    if params.decoder.type == 'simple_decoder':
      decoder = tf_example_decoder.TfExampleDecoder(
          include_mask=self._task_config.model.include_mask,
          regenerate_source_id=decoder_cfg.regenerate_source_id,
          mask_binarize_threshold=decoder_cfg.mask_binarize_threshold)
    elif params.decoder.type == 'label_map_decoder':
      decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
          label_map=decoder_cfg.label_map,
          include_mask=self._task_config.model.include_mask,
          regenerate_source_id=decoder_cfg.regenerate_source_id,
          mask_binarize_threshold=decoder_cfg.mask_binarize_threshold)
    else:
      raise ValueError('Unknown decoder type: {}!'.format(params.decoder.type))

    parser = maskrcnn_input.Parser(
        output_size=self.task_config.model.input_size[:2],
        min_level=self.task_config.model.min_level,
        max_level=self.task_config.model.max_level,
        num_scales=self.task_config.model.anchor.num_scales,
        aspect_ratios=self.task_config.model.anchor.aspect_ratios,
        anchor_size=self.task_config.model.anchor.anchor_size,
        dtype=params.dtype,
        rpn_match_threshold=params.parser.rpn_match_threshold,
        rpn_unmatched_threshold=params.parser.rpn_unmatched_threshold,
        rpn_batch_size_per_im=params.parser.rpn_batch_size_per_im,
        rpn_fg_fraction=params.parser.rpn_fg_fraction,
        aug_rand_hflip=params.parser.aug_rand_hflip,
        aug_scale_min=params.parser.aug_scale_min,
        aug_scale_max=params.parser.aug_scale_max,
        skip_crowd_during_training=params.parser.skip_crowd_during_training,
        max_num_instances=params.parser.max_num_instances,
        include_mask=self._task_config.model.include_mask,
        mask_crop_size=params.parser.mask_crop_size)

    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training))
    dataset = reader.read(input_context=input_context)

    return dataset

  def build_losses(self,
                   outputs: Mapping[str, Any],
                   labels: Mapping[str, Any],
                   aux_losses: Optional[Any] = None):
    """Build Mask R-CNN losses.

    Args:
      outputs: Model outputs keyed by name (RPN scores/boxes, per-head
        class/box outputs and targets, optional mask outputs/targets).
      labels: Groundtruth labels keyed by name (RPN score/box targets).
      aux_losses: Optional auxiliary losses (e.g. regularization) to add.

    Returns:
      A dict with the total loss and each individual loss component.
    """
    params = self.task_config
    cascade_ious = params.model.roi_sampler.cascade_iou_thresholds

    rpn_score_loss_fn = maskrcnn_losses.RpnScoreLoss(
        tf.shape(outputs['box_outputs'])[1])
    rpn_box_loss_fn = maskrcnn_losses.RpnBoxLoss(
        params.losses.rpn_huber_loss_delta)
    rpn_score_loss = tf.reduce_mean(
        rpn_score_loss_fn(outputs['rpn_scores'], labels['rpn_score_targets']))
    rpn_box_loss = tf.reduce_mean(
        rpn_box_loss_fn(outputs['rpn_boxes'], labels['rpn_box_targets']))

    frcnn_cls_loss_fn = maskrcnn_losses.FastrcnnClassLoss()
    frcnn_box_loss_fn = maskrcnn_losses.FastrcnnBoxLoss(
        params.losses.frcnn_huber_loss_delta,
        params.model.detection_head.class_agnostic_bbox_pred)

    # Final cls/box losses are computed as an average of all detection heads.
    frcnn_cls_loss = 0.0
    frcnn_box_loss = 0.0
    # Head 0 uses the unsuffixed output keys; cascade heads 1..N use
    # '<key>_{i}' suffixed keys.
    num_det_heads = 1 if cascade_ious is None else 1 + len(cascade_ious)
    for cas_num in range(num_det_heads):
      frcnn_cls_loss_i = tf.reduce_mean(
          frcnn_cls_loss_fn(
              outputs['class_outputs_{}'.format(cas_num)
                      if cas_num else 'class_outputs'],
              outputs['class_targets_{}'.format(cas_num)
                      if cas_num else 'class_targets']))
      frcnn_box_loss_i = tf.reduce_mean(
          frcnn_box_loss_fn(
              outputs['box_outputs_{}'.format(cas_num)
                      if cas_num else 'box_outputs'],
              outputs['class_targets_{}'.format(cas_num)
                      if cas_num else 'class_targets'],
              outputs['box_targets_{}'.format(cas_num)
                      if cas_num else 'box_targets']))
      frcnn_cls_loss += frcnn_cls_loss_i
      frcnn_box_loss += frcnn_box_loss_i
    frcnn_cls_loss /= num_det_heads
    frcnn_box_loss /= num_det_heads

    if params.model.include_mask:
      mask_loss_fn = maskrcnn_losses.MaskrcnnLoss()
      mask_class_targets = outputs['mask_class_targets']
      if self._task_config.allowed_mask_class_ids is not None:
        # Classes with ID=0 are ignored by mask_loss_fn in loss computation.
        mask_class_targets = zero_out_disallowed_class_ids(
            mask_class_targets, self._task_config.allowed_mask_class_ids)
      mask_loss = tf.reduce_mean(
          mask_loss_fn(outputs['mask_outputs'], outputs['mask_targets'],
                       mask_class_targets))
    else:
      mask_loss = 0.0

    # Weighted sum of the individual components per the loss config.
    model_loss = (
        params.losses.rpn_score_weight * rpn_score_loss +
        params.losses.rpn_box_weight * rpn_box_loss +
        params.losses.frcnn_class_weight * frcnn_cls_loss +
        params.losses.frcnn_box_weight * frcnn_box_loss +
        params.losses.mask_weight * mask_loss)

    total_loss = model_loss
    if aux_losses:
      reg_loss = tf.reduce_sum(aux_losses)
      total_loss = model_loss + reg_loss

    total_loss = params.losses.loss_weight * total_loss
    losses = {
        'total_loss': total_loss,
        'rpn_score_loss': rpn_score_loss,
        'rpn_box_loss': rpn_box_loss,
        'frcnn_cls_loss': frcnn_cls_loss,
        'frcnn_box_loss': frcnn_box_loss,
        'mask_loss': mask_loss,
        'model_loss': model_loss,
    }
    return losses

  def _build_coco_metrics(self):
    """Build COCO metrics evaluator."""
    if (not self._task_config.model.include_mask
       ) or self._task_config.annotation_file:
      self.coco_metric = coco_evaluator.COCOEvaluator(
          annotation_file=self._task_config.annotation_file,
          include_mask=self._task_config.model.include_mask,
          per_category_metrics=self._task_config.per_category_metrics)
    else:
      # Builds COCO-style annotation file if include_mask is True, and
      # annotation_file isn't provided.
      annotation_path = os.path.join(self._logging_dir, 'annotation.json')
      if tf.io.gfile.exists(annotation_path):
        logging.info(
            'annotation.json file exists, skipping creating the annotation'
            ' file.')
      else:
        # NOTE(review): the two checks below only log and fall through to
        # annotation generation rather than raising — confirm best-effort
        # behavior is intended.
        if self._task_config.validation_data.num_examples <= 0:
          logging.info('validation_data.num_examples needs to be > 0')
        if not self._task_config.validation_data.input_path:
          logging.info('Can not create annotation file for tfds.')
        logging.info(
            'Creating coco-style annotation file: %s', annotation_path)
        coco_utils.scan_and_generator_annotation_file(
            self._task_config.validation_data.input_path,
            self._task_config.validation_data.file_type,
            self._task_config.validation_data.num_examples,
            self.task_config.model.include_mask, annotation_path,
            regenerate_source_id=self._task_config.validation_data.decoder
            .simple_decoder.regenerate_source_id)
      self.coco_metric = coco_evaluator.COCOEvaluator(
          annotation_file=annotation_path,
          include_mask=self._task_config.model.include_mask,
          per_category_metrics=self._task_config.per_category_metrics)

  def build_metrics(self, training: bool = True):
    """Build detection metrics.

    During training, returns per-loss `Mean` trackers; during evaluation,
    returns an empty list and instead sets up the COCO and/or Waymo
    evaluators as instance attributes.
    """
    metrics = []
    if training:
      metric_names = [
          'total_loss',
          'rpn_score_loss',
          'rpn_box_loss',
          'frcnn_cls_loss',
          'frcnn_box_loss',
          'mask_loss',
          'model_loss'
      ]
      for name in metric_names:
        metrics.append(tf.keras.metrics.Mean(name, dtype=tf.float32))

    else:
      if self._task_config.use_coco_metrics:
        self._build_coco_metrics()
      if self._task_config.use_wod_metrics:
        # To use Waymo open dataset metrics, please install one of the pip
        # package `waymo-open-dataset-tf-*` from
        # https://github.com/waymo-research/waymo-open-dataset/blob/master/docs/quick_start.md#use-pre-compiled-pippip3-packages-for-linux
        # Note that the package is built with specific tensorflow version and
        # will produce error if it does not match the tf version that is
        # currently used.
        try:
          from official.vision.beta.evaluation import wod_detection_evaluator  # pylint: disable=g-import-not-at-top
        except ModuleNotFoundError:
          logging.error('waymo-open-dataset should be installed to enable Waymo'
                        ' evaluator.')
          raise
        self.wod_metric = wod_detection_evaluator.WOD2dDetectionEvaluator()

    return metrics

  def train_step(self,
                 inputs: Tuple[Any, Any],
                 model: tf.keras.Model,
                 optimizer: tf.keras.optimizers.Optimizer,
                 metrics: Optional[List[Any]] = None):
    """Does forward and backward.

    Args:
      inputs: a dictionary of input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    images, labels = inputs
    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    with tf.GradientTape() as tape:
      # image_info row 1 holds the (possibly rescaled) image shape — TODO
      # confirm against the Parser's image_info layout.
      outputs = model(
          images,
          image_shape=labels['image_info'][:, 1, :],
          anchor_boxes=labels['anchor_boxes'],
          gt_boxes=labels['gt_boxes'],
          gt_classes=labels['gt_classes'],
          gt_masks=(labels['gt_masks'] if self.task_config.model.include_mask
                    else None),
          training=True)
      # Cast to float32 so losses are computed in full precision under
      # mixed-precision policies.
      outputs = tf.nest.map_structure(
          lambda x: tf.cast(x, tf.float32), outputs)

      # Computes per-replica loss.
      losses = self.build_losses(
          outputs=outputs, labels=labels, aux_losses=model.losses)
      # Scale loss since gradients are summed across replicas during
      # allreduce.
      scaled_loss = losses['total_loss'] / num_replicas

      # For mixed_precision policy, when LossScaleOptimizer is used, loss is
      # scaled for numerical stability.
      if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    # Scales back gradient when LossScaleOptimizer is used.
    if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    logs = {self.loss: losses['total_loss']}

    if metrics:
      for m in metrics:
        m.update_state(losses[m.name])

    return logs

  def validation_step(self,
                      inputs: Tuple[Any, Any],
                      model: tf.keras.Model,
                      metrics: Optional[List[Any]] = None):
    """Validation step.

    Args:
      inputs: a dictionary of input tensors.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs. Evaluator inputs (groundtruths, model outputs)
      are stashed in the logs under the evaluator's name for aggregate_logs.
    """
    images, labels = inputs

    outputs = model(
        images,
        anchor_boxes=labels['anchor_boxes'],
        image_shape=labels['image_info'][:, 1, :],
        training=False)

    # No loss is computed during evaluation; log a constant 0 for the
    # trainer's loss slot.
    logs = {self.loss: 0}
    if self._task_config.use_coco_metrics:
      coco_model_outputs = {
          'detection_boxes': outputs['detection_boxes'],
          'detection_scores': outputs['detection_scores'],
          'detection_classes': outputs['detection_classes'],
          'num_detections': outputs['num_detections'],
          'source_id': labels['groundtruths']['source_id'],
          'image_info': labels['image_info']
      }
      if self.task_config.model.include_mask:
        coco_model_outputs.update({
            'detection_masks': outputs['detection_masks'],
        })
      logs.update(
          {self.coco_metric.name: (labels['groundtruths'], coco_model_outputs)})
    if self.task_config.use_wod_metrics:
      wod_model_outputs = {
          'detection_boxes': outputs['detection_boxes'],
          'detection_scores': outputs['detection_scores'],
          'detection_classes': outputs['detection_classes'],
          'num_detections': outputs['num_detections'],
          'source_id': labels['groundtruths']['source_id'],
          'image_info': labels['image_info']
      }
      logs.update(
          {self.wod_metric.name: (labels['groundtruths'], wod_model_outputs)})
    return logs

  def aggregate_logs(self, state=None, step_outputs=None):
    """Accumulates per-step evaluator inputs; `state is None` marks step 1."""
    if self._task_config.use_coco_metrics:
      if state is None:
        self.coco_metric.reset_states()
      self.coco_metric.update_state(
          step_outputs[self.coco_metric.name][0],
          step_outputs[self.coco_metric.name][1])
    if self._task_config.use_wod_metrics:
      if state is None:
        self.wod_metric.reset_states()
      self.wod_metric.update_state(
          step_outputs[self.wod_metric.name][0],
          step_outputs[self.wod_metric.name][1])
    if state is None:
      # Create an arbitrary state to indicate it's not the first step in the
      # following calls to this function.
      state = True
    return state

  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
    """Computes final evaluator results from the accumulated state."""
    logs = {}
    if self._task_config.use_coco_metrics:
      logs.update(self.coco_metric.result())
    if self._task_config.use_wod_metrics:
      logs.update(self.wod_metric.result())
    return logs
official/vision/beta/tasks/retinanet.py
deleted
100644 → 0
View file @
f47405b5
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RetinaNet task definition."""
from
typing
import
Any
,
List
,
Mapping
,
Optional
,
Tuple
from
absl
import
logging
import
tensorflow
as
tf
from
official.common
import
dataset_fn
from
official.core
import
base_task
from
official.core
import
task_factory
from
official.vision.beta.configs
import
retinanet
as
exp_cfg
from
official.vision.beta.dataloaders
import
input_reader_factory
from
official.vision.beta.dataloaders
import
retinanet_input
from
official.vision.beta.dataloaders
import
tf_example_decoder
from
official.vision.beta.dataloaders
import
tfds_factory
from
official.vision.beta.dataloaders
import
tf_example_label_map_decoder
from
official.vision.beta.evaluation
import
coco_evaluator
from
official.vision.beta.losses
import
focal_loss
from
official.vision.beta.losses
import
loss_utils
from
official.vision.beta.modeling
import
factory
@task_factory.register_task_cls(exp_cfg.RetinaNetTask)
class RetinaNetTask(base_task.Task):
  """A single-replica view of training procedure.

  RetinaNet task provides artifacts for training/evaluation procedures,
  including loading/iterating over Datasets, initializing the model,
  calculating the loss, post-processing, and customized metrics with
  reduction.
  """

  def build_model(self):
    """Builds the RetinaNet model from the task config.

    Returns:
      A `tf.keras.Model` built by `factory.build_retinanet`, with optional
      L2 regularization attached to its layers.
    """
    input_specs = tf.keras.layers.InputSpec(
        shape=[None] + self.task_config.model.input_size)

    l2_weight_decay = self.task_config.losses.l2_weight_decay
    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    l2_regularizer = (tf.keras.regularizers.l2(
        l2_weight_decay / 2.0) if l2_weight_decay else None)

    model = factory.build_retinanet(
        input_specs=input_specs,
        model_config=self.task_config.model,
        l2_regularizer=l2_regularizer)
    return model

  def initialize(self, model: tf.keras.Model):
    """Loads a pretrained checkpoint into `model`, if one is configured.

    Depending on `init_checkpoint_modules`, restores either all checkpoint
    items or only the backbone and/or decoder submodules. No-op when
    `init_checkpoint` is unset.
    """
    if not self.task_config.init_checkpoint:
      return

    ckpt_dir_or_file = self.task_config.init_checkpoint
    # A directory means "use the latest checkpoint inside it".
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

    # Restoring checkpoint.
    if self.task_config.init_checkpoint_modules == 'all':
      ckpt = tf.train.Checkpoint(**model.checkpoint_items)
      status = ckpt.read(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()
    else:
      ckpt_items = {}
      if 'backbone' in self.task_config.init_checkpoint_modules:
        ckpt_items.update(backbone=model.backbone)
      if 'decoder' in self.task_config.init_checkpoint_modules:
        ckpt_items.update(decoder=model.decoder)

      ckpt = tf.train.Checkpoint(**ckpt_items)
      status = ckpt.read(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()

    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)

  def build_inputs(self,
                   params: exp_cfg.DataConfig,
                   input_context: Optional[tf.distribute.InputContext] = None):
    """Builds the input dataset.

    Args:
      params: the data config describing the source (TFDS or files), decoder
        and parser settings.
      input_context: optional `tf.distribute.InputContext` for sharding.

    Returns:
      A `tf.data.Dataset` of parsed (features, labels) pairs.

    Raises:
      ValueError: if the configured decoder type is unknown.
    """
    if params.tfds_name:
      decoder = tfds_factory.get_detection_decoder(params.tfds_name)
    else:
      decoder_cfg = params.decoder.get()
      if params.decoder.type == 'simple_decoder':
        decoder = tf_example_decoder.TfExampleDecoder(
            regenerate_source_id=decoder_cfg.regenerate_source_id)
      elif params.decoder.type == 'label_map_decoder':
        decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
            label_map=decoder_cfg.label_map,
            regenerate_source_id=decoder_cfg.regenerate_source_id)
      else:
        raise ValueError('Unknown decoder type: {}!'.format(
            params.decoder.type))

    parser = retinanet_input.Parser(
        output_size=self.task_config.model.input_size[:2],
        min_level=self.task_config.model.min_level,
        max_level=self.task_config.model.max_level,
        num_scales=self.task_config.model.anchor.num_scales,
        aspect_ratios=self.task_config.model.anchor.aspect_ratios,
        anchor_size=self.task_config.model.anchor.anchor_size,
        dtype=params.dtype,
        match_threshold=params.parser.match_threshold,
        unmatched_threshold=params.parser.unmatched_threshold,
        aug_type=params.parser.aug_type,
        aug_rand_hflip=params.parser.aug_rand_hflip,
        aug_scale_min=params.parser.aug_scale_min,
        aug_scale_max=params.parser.aug_scale_max,
        skip_crowd_during_training=params.parser.skip_crowd_during_training,
        max_num_instances=params.parser.max_num_instances)

    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training))
    dataset = reader.read(input_context=input_context)

    return dataset

  def build_attribute_loss(self,
                           attribute_heads: List[exp_cfg.AttributeHead],
                           outputs: Mapping[str, Any],
                           labels: Mapping[str, Any],
                           box_sample_weight: tf.Tensor) -> float:
    """Computes attribute loss.

    Args:
      attribute_heads: a list of attribute head configs.
      outputs: RetinaNet model outputs.
      labels: RetinaNet labels.
      box_sample_weight: normalized bounding box sample weights.

    Returns:
      Attribute loss of all attribute heads.

    Raises:
      ValueError: if an attribute head is missing from labels/outputs, or its
        type is not 'regression' (the only supported type here).
    """
    attribute_loss = 0.0
    for head in attribute_heads:
      if head.name not in labels['attribute_targets']:
        raise ValueError(f'Attribute {head.name} not found in label targets.')
      if head.name not in outputs['attribute_outputs']:
        raise ValueError(f'Attribute {head.name} not found in model outputs.')

      y_true_att = loss_utils.multi_level_flatten(
          labels['attribute_targets'][head.name], last_dim=head.size)
      y_pred_att = loss_utils.multi_level_flatten(
          outputs['attribute_outputs'][head.name], last_dim=head.size)
      if head.type == 'regression':
        att_loss_fn = tf.keras.losses.Huber(
            1.0, reduction=tf.keras.losses.Reduction.SUM)
        att_loss = att_loss_fn(
            y_true=y_true_att,
            y_pred=y_pred_att,
            sample_weight=box_sample_weight)
      else:
        raise ValueError(f'Attribute type {head.type} not supported.')
      attribute_loss += att_loss

    return attribute_loss

  def build_losses(self,
                   outputs: Mapping[str, Any],
                   labels: Mapping[str, Any],
                   aux_losses: Optional[Any] = None):
    """Builds RetinaNet losses.

    Args:
      outputs: model outputs with 'cls_outputs' and 'box_outputs' (and
        optionally 'attribute_outputs').
      labels: parsed labels with class/box targets and sample weights.
      aux_losses: auxiliary loss tensors, i.e. `losses` in tf.keras.Model.

    Returns:
      A tuple of (total_loss, cls_loss, box_loss, model_loss). `total_loss`
      includes regularization and is scaled by `losses.loss_weight`.
    """
    params = self.task_config
    attribute_heads = self.task_config.model.head.attribute_heads

    cls_loss_fn = focal_loss.FocalLoss(
        alpha=params.losses.focal_loss_alpha,
        gamma=params.losses.focal_loss_gamma,
        reduction=tf.keras.losses.Reduction.SUM)
    box_loss_fn = tf.keras.losses.Huber(
        params.losses.huber_loss_delta,
        reduction=tf.keras.losses.Reduction.SUM)

    # Sums all positives in a batch for normalization and avoids zero
    # num_positives_sum, which would lead to inf loss during training
    cls_sample_weight = labels['cls_weights']
    box_sample_weight = labels['box_weights']
    num_positives = tf.reduce_sum(box_sample_weight) + 1.0
    cls_sample_weight = cls_sample_weight / num_positives
    box_sample_weight = box_sample_weight / num_positives
    # Flatten per-level targets/outputs so a single loss call covers all FPN
    # levels.
    y_true_cls = loss_utils.multi_level_flatten(
        labels['cls_targets'], last_dim=None)
    y_true_cls = tf.one_hot(y_true_cls, params.model.num_classes)
    y_pred_cls = loss_utils.multi_level_flatten(
        outputs['cls_outputs'], last_dim=params.model.num_classes)
    y_true_box = loss_utils.multi_level_flatten(
        labels['box_targets'], last_dim=4)
    y_pred_box = loss_utils.multi_level_flatten(
        outputs['box_outputs'], last_dim=4)

    cls_loss = cls_loss_fn(
        y_true=y_true_cls, y_pred=y_pred_cls, sample_weight=cls_sample_weight)
    box_loss = box_loss_fn(
        y_true=y_true_box, y_pred=y_pred_box, sample_weight=box_sample_weight)

    model_loss = cls_loss + params.losses.box_loss_weight * box_loss
    if attribute_heads:
      model_loss += self.build_attribute_loss(attribute_heads, outputs, labels,
                                              box_sample_weight)

    total_loss = model_loss
    if aux_losses:
      reg_loss = tf.reduce_sum(aux_losses)
      total_loss = model_loss + reg_loss

    total_loss = params.losses.loss_weight * total_loss

    return total_loss, cls_loss, box_loss, model_loss

  def build_metrics(self, training: bool = True):
    """Builds detection metrics.

    During training only the mean loss trackers are returned; for evaluation a
    COCO evaluator is additionally created and stored on `self.coco_metric`.
    """
    metrics = []
    metric_names = ['total_loss', 'cls_loss', 'box_loss', 'model_loss']
    for name in metric_names:
      metrics.append(tf.keras.metrics.Mean(name, dtype=tf.float32))

    if not training:
      # TFDS datasets carry their own groundtruth; a separate annotation file
      # is only valid for file-based datasets.
      if (self.task_config.validation_data.tfds_name and
          self.task_config.annotation_file):
        raise ValueError(
            "Can't evaluate using annotation file when TFDS is used.")
      self.coco_metric = coco_evaluator.COCOEvaluator(
          annotation_file=self.task_config.annotation_file,
          include_mask=False,
          per_category_metrics=self.task_config.per_category_metrics)

    return metrics

  def train_step(self,
                 inputs: Tuple[Any, Any],
                 model: tf.keras.Model,
                 optimizer: tf.keras.optimizers.Optimizer,
                 metrics: Optional[List[Any]] = None):
    """Does forward and backward.

    Args:
      inputs: a tuple of (features, labels) input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs
    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    with tf.GradientTape() as tape:
      outputs = model(features, training=True)
      # Cast to float32 so loss math is stable under mixed precision.
      outputs = tf.nest.map_structure(
          lambda x: tf.cast(x, tf.float32), outputs)

      # Computes per-replica loss.
      loss, cls_loss, box_loss, model_loss = self.build_losses(
          outputs=outputs, labels=labels, aux_losses=model.losses)
      # Gradients are summed across replicas by allreduce; divide so the
      # effective loss is a mean.
      scaled_loss = loss / num_replicas

      # For mixed_precision policy, when LossScaleOptimizer is used, loss is
      # scaled for numerical stability.
      if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    # Scales back gradient when LossScaleOptimizer is used.
    if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    logs = {self.loss: loss}

    all_losses = {
        'total_loss': loss,
        'cls_loss': cls_loss,
        'box_loss': box_loss,
        'model_loss': model_loss,
    }
    if metrics:
      for m in metrics:
        m.update_state(all_losses[m.name])
        logs.update({m.name: m.result()})

    return logs

  def validation_step(self,
                      inputs: Tuple[Any, Any],
                      model: tf.keras.Model,
                      metrics: Optional[List[Any]] = None):
    """Validation step.

    Args:
      inputs: a tuple of (features, labels) input tensors.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs, including the groundtruths/predictions pair the
      COCO evaluator consumes in `aggregate_logs`.
    """
    features, labels = inputs

    outputs = model(
        features,
        anchor_boxes=labels['anchor_boxes'],
        image_shape=labels['image_info'][:, 1, :],
        training=False)
    loss, cls_loss, box_loss, model_loss = self.build_losses(
        outputs=outputs, labels=labels, aux_losses=model.losses)
    logs = {self.loss: loss}

    all_losses = {
        'total_loss': loss,
        'cls_loss': cls_loss,
        'box_loss': box_loss,
        'model_loss': model_loss,
    }

    coco_model_outputs = {
        'detection_boxes': outputs['detection_boxes'],
        'detection_scores': outputs['detection_scores'],
        'detection_classes': outputs['detection_classes'],
        'num_detections': outputs['num_detections'],
        'source_id': labels['groundtruths']['source_id'],
        'image_info': labels['image_info']
    }
    logs.update({
        self.coco_metric.name: (labels['groundtruths'], coco_model_outputs)})
    if metrics:
      for m in metrics:
        m.update_state(all_losses[m.name])
        logs.update({m.name: m.result()})
    return logs

  def aggregate_logs(self, state=None, step_outputs=None):
    """Accumulates per-step eval outputs into the COCO evaluator state."""
    if state is None:
      # First call of the eval loop: start from a clean metric state.
      self.coco_metric.reset_states()
      state = self.coco_metric
    self.coco_metric.update_state(step_outputs[self.coco_metric.name][0],
                                  step_outputs[self.coco_metric.name][1])
    return state

  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
    """Returns the final COCO metrics after evaluation."""
    return self.coco_metric.result()
official/vision/beta/tasks/semantic_segmentation.py
deleted
100644 → 0
View file @
f47405b5
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image segmentation task definition."""
from
typing
import
Any
,
Optional
,
List
,
Tuple
,
Mapping
,
Union
from
absl
import
logging
import
tensorflow
as
tf
from
official.common
import
dataset_fn
from
official.core
import
base_task
from
official.core
import
task_factory
from
official.vision.beta.configs
import
semantic_segmentation
as
exp_cfg
from
official.vision.beta.dataloaders
import
input_reader_factory
from
official.vision.beta.dataloaders
import
segmentation_input
from
official.vision.beta.dataloaders
import
tfds_factory
from
official.vision.beta.evaluation
import
segmentation_metrics
from
official.vision.beta.losses
import
segmentation_losses
from
official.vision.beta.modeling
import
factory
@task_factory.register_task_cls(exp_cfg.SemanticSegmentationTask)
class SemanticSegmentationTask(base_task.Task):
  """A task for semantic segmentation."""

  def build_model(self):
    """Builds the segmentation model with optional L2 regularization."""
    input_specs = tf.keras.layers.InputSpec(
        shape=[None] + self.task_config.model.input_size)

    l2_weight_decay = self.task_config.losses.l2_weight_decay
    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    l2_regularizer = (tf.keras.regularizers.l2(
        l2_weight_decay / 2.0) if l2_weight_decay else None)

    model = factory.build_segmentation_model(
        input_specs=input_specs,
        model_config=self.task_config.model,
        l2_regularizer=l2_regularizer)
    return model

  def initialize(self, model: tf.keras.Model):
    """Loads a pretrained checkpoint into `model`, if one is configured.

    Restores everything when 'all' appears in `init_checkpoint_modules`,
    otherwise only the requested backbone/decoder submodules.
    """
    if not self.task_config.init_checkpoint:
      return

    ckpt_dir_or_file = self.task_config.init_checkpoint
    # A directory means "use the latest checkpoint inside it".
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

    # Restoring checkpoint.
    if 'all' in self.task_config.init_checkpoint_modules:
      ckpt = tf.train.Checkpoint(**model.checkpoint_items)
      status = ckpt.read(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()
    else:
      ckpt_items = {}
      if 'backbone' in self.task_config.init_checkpoint_modules:
        ckpt_items.update(backbone=model.backbone)
      if 'decoder' in self.task_config.init_checkpoint_modules:
        ckpt_items.update(decoder=model.decoder)

      ckpt = tf.train.Checkpoint(**ckpt_items)
      status = ckpt.read(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()

    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)

  def build_inputs(self,
                   params: exp_cfg.DataConfig,
                   input_context: Optional[tf.distribute.InputContext] = None):
    """Builds the segmentation input pipeline.

    Args:
      params: the data config describing source, decoder and parser settings.
      input_context: optional `tf.distribute.InputContext` for sharding.

    Returns:
      A `tf.data.Dataset` of parsed (features, labels) pairs.
    """
    ignore_label = self.task_config.losses.ignore_label

    if params.tfds_name:
      decoder = tfds_factory.get_segmentation_decoder(params.tfds_name)
    else:
      decoder = segmentation_input.Decoder()

    parser = segmentation_input.Parser(
        output_size=params.output_size,
        crop_size=params.crop_size,
        ignore_label=ignore_label,
        resize_eval_groundtruth=params.resize_eval_groundtruth,
        groundtruth_padded_size=params.groundtruth_padded_size,
        aug_scale_min=params.aug_scale_min,
        aug_scale_max=params.aug_scale_max,
        aug_rand_hflip=params.aug_rand_hflip,
        preserve_aspect_ratio=params.preserve_aspect_ratio,
        dtype=params.dtype)

    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training))

    dataset = reader.read(input_context=input_context)

    return dataset

  def build_losses(self,
                   labels: Mapping[str, tf.Tensor],
                   model_outputs: Union[Mapping[str, tf.Tensor], tf.Tensor],
                   aux_losses: Optional[Any] = None):
    """Segmentation loss.

    Args:
      labels: labels.
      model_outputs: Output logits of the classifier.
      aux_losses: auxiliarly loss tensors, i.e. `losses` in keras.Model.

    Returns:
      The total loss tensor.
    """
    loss_params = self._task_config.losses
    segmentation_loss_fn = segmentation_losses.SegmentationLoss(
        loss_params.label_smoothing,
        loss_params.class_weights,
        loss_params.ignore_label,
        use_groundtruth_dimension=loss_params.use_groundtruth_dimension,
        top_k_percent_pixels=loss_params.top_k_percent_pixels)

    total_loss = segmentation_loss_fn(model_outputs['logits'], labels['masks'])

    if 'mask_scores' in model_outputs:
      mask_scoring_loss_fn = segmentation_losses.MaskScoringLoss(
          loss_params.ignore_label)
      total_loss += mask_scoring_loss_fn(
          model_outputs['mask_scores'],
          model_outputs['logits'],
          labels['masks'])

    if aux_losses:
      total_loss += tf.add_n(aux_losses)

    total_loss = loss_params.loss_weight * total_loss

    return total_loss

  def process_metrics(self, metrics, labels, model_outputs, **kwargs):
    """Process and update metrics.

    Called when using custom training loop API.

    Args:
      metrics: a nested structure of metrics objects. The return of function
        self.build_metrics.
      labels: a tensor or a nested structure of tensors.
      model_outputs: a tensor or a nested structure of tensors. For example,
        output of the keras model built by self.build_model.
      **kwargs: other args.
    """
    for metric in metrics:
      # Fix: the original compared strings with `is`, which tests object
      # identity and only happens to work because of CPython string
      # interning; `==` is the correct, portable comparison.
      if metric.name == 'mask_scores_mse':
        actual_mask_scores = segmentation_losses.get_actual_mask_scores(
            model_outputs['logits'], labels['masks'],
            self.task_config.losses.ignore_label)
        metric.update_state(actual_mask_scores, model_outputs['mask_scores'])
      else:
        metric.update_state(labels, model_outputs['logits'])

  def build_metrics(self, training: bool = True):
    """Gets streaming metrics for training/validation.

    For training, mean IoU (and mask-score MSE when a mask scoring head is
    configured) are returned. For evaluation, a per-class IoU metric is stored
    on `self.iou_metric` and updated via `aggregate_logs`.
    """
    metrics = []
    if training and self.task_config.evaluation.report_train_mean_iou:
      metrics.append(segmentation_metrics.MeanIoU(
          name='mean_iou',
          num_classes=self.task_config.model.num_classes,
          rescale_predictions=False,
          dtype=tf.float32))
      if self.task_config.model.get('mask_scoring_head'):
        metrics.append(
            tf.keras.metrics.MeanSquaredError(name='mask_scores_mse'))
    else:
      self.iou_metric = segmentation_metrics.PerClassIoU(
          name='per_class_iou',
          num_classes=self.task_config.model.num_classes,
          rescale_predictions=not self.task_config.validation_data
          .resize_eval_groundtruth,
          dtype=tf.float32)
      if (self.task_config.validation_data.resize_eval_groundtruth and
          self.task_config.model.get('mask_scoring_head')):
        # Mask scores metric can only be computed if labels are scaled to
        # match predicted mask scores.
        metrics.append(
            tf.keras.metrics.MeanSquaredError(name='mask_scores_mse'))

      # Update state on CPU if TPUStrategy due to dynamic resizing.
      self._process_iou_metric_on_cpu = isinstance(
          tf.distribute.get_strategy(),
          tf.distribute.TPUStrategy)

    return metrics

  def train_step(self,
                 inputs: Tuple[Any, Any],
                 model: tf.keras.Model,
                 optimizer: tf.keras.optimizers.Optimizer,
                 metrics: Optional[List[Any]] = None):
    """Does forward and backward.

    Args:
      inputs: a tuple of (features, labels) input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs
    input_partition_dims = self.task_config.train_input_partition_dims
    if input_partition_dims:
      strategy = tf.distribute.get_strategy()
      features = strategy.experimental_split_to_logical_devices(
          features, input_partition_dims)

    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    with tf.GradientTape() as tape:
      outputs = model(features, training=True)
      if isinstance(outputs, tf.Tensor):
        outputs = {'logits': outputs}
      # Casting output layer as float32 is necessary when mixed_precision is
      # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
      outputs = tf.nest.map_structure(
          lambda x: tf.cast(x, tf.float32), outputs)

      # Computes per-replica loss.
      loss = self.build_losses(
          model_outputs=outputs, labels=labels, aux_losses=model.losses)
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      scaled_loss = loss / num_replicas

      # For mixed_precision policy, when LossScaleOptimizer is used, loss is
      # scaled for numerical stability.
      if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    # Scales back gradient before apply_gradients when LossScaleOptimizer is
    # used.
    if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    logs = {self.loss: loss}
    if metrics:
      self.process_metrics(metrics, labels, outputs)
      logs.update({m.name: m.result() for m in metrics})

    return logs

  def validation_step(self,
                      inputs: Tuple[Any, Any],
                      model: tf.keras.Model,
                      metrics: Optional[List[Any]] = None):
    """Validation step.

    Args:
      inputs: a tuple of (features, labels) input tensors.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs
    input_partition_dims = self.task_config.eval_input_partition_dims
    if input_partition_dims:
      strategy = tf.distribute.get_strategy()
      features = strategy.experimental_split_to_logical_devices(
          features, input_partition_dims)

    outputs = self.inference_step(features, model)
    if isinstance(outputs, tf.Tensor):
      outputs = {'logits': outputs}
    outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)

    if self.task_config.validation_data.resize_eval_groundtruth:
      loss = self.build_losses(
          model_outputs=outputs, labels=labels, aux_losses=model.losses)
    else:
      # Without resized groundtruth the loss is undefined; report 0 and rely
      # on the IoU metrics.
      loss = 0

    logs = {self.loss: loss}

    if self._process_iou_metric_on_cpu:
      # Defer the update to aggregate_logs, which runs on the host.
      logs.update({self.iou_metric.name: (labels, outputs['logits'])})
    else:
      self.iou_metric.update_state(labels, outputs['logits'])

    if metrics:
      self.process_metrics(metrics, labels, outputs)
      logs.update({m.name: m.result() for m in metrics})

    return logs

  def inference_step(self, inputs: tf.Tensor, model: tf.keras.Model):
    """Performs the forward step."""
    return model(inputs, training=False)

  def aggregate_logs(self, state=None, step_outputs=None):
    """Accumulates per-step eval outputs into the IoU metric state."""
    if state is None:
      self.iou_metric.reset_states()
      state = self.iou_metric
    if self._process_iou_metric_on_cpu:
      self.iou_metric.update_state(step_outputs[self.iou_metric.name][0],
                                   step_outputs[self.iou_metric.name][1])
    return state

  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
    """Returns mean IoU (and optionally per-class IoU) after evaluation."""
    result = {}
    ious = self.iou_metric.result()
    # TODO(arashwan): support loading class name from a label map file.
    if self.task_config.evaluation.report_per_class_iou:
      for i, value in enumerate(ious.numpy()):
        result.update({'iou/{}'.format(i): value})
    # Computes mean IoU
    result.update({'mean_iou': tf.reduce_mean(ious).numpy()})
    return result
official/vision/beta/tasks/video_classification.py
deleted
100644 → 0
View file @
f47405b5
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Video classification task definition."""
from
typing
import
Any
,
Optional
,
List
,
Tuple
from
absl
import
logging
import
tensorflow
as
tf
from
official.core
import
base_task
from
official.core
import
task_factory
from
official.modeling
import
tf_utils
from
official.vision.beta.configs
import
video_classification
as
exp_cfg
from
official.vision.beta.dataloaders
import
input_reader_factory
from
official.vision.beta.dataloaders
import
video_input
from
official.vision.beta.modeling
import
factory_3d
@
task_factory
.
register_task_cls
(
exp_cfg
.
VideoClassificationTask
)
class
VideoClassificationTask
(
base_task
.
Task
):
"""A task for video classification."""
def
_get_num_classes
(
self
):
"""Gets the number of classes."""
return
self
.
task_config
.
train_data
.
num_classes
def
_get_feature_shape
(
self
):
"""Get the common feature shape for train and eval."""
return
[
d1
if
d1
==
d2
else
None
for
d1
,
d2
in
zip
(
self
.
task_config
.
train_data
.
feature_shape
,
self
.
task_config
.
validation_data
.
feature_shape
)
]
def
_get_num_test_views
(
self
):
"""Gets number of views for test."""
num_test_clips
=
self
.
task_config
.
validation_data
.
num_test_clips
num_test_crops
=
self
.
task_config
.
validation_data
.
num_test_crops
num_test_views
=
num_test_clips
*
num_test_crops
return
num_test_views
def
_is_multilabel
(
self
):
"""If the label is multi-labels."""
return
self
.
task_config
.
train_data
.
is_multilabel
  def build_model(self):
    """Builds the video classification model.

    The input spec uses the feature shape common to train and eval (dynamic
    dims become None), and L2 regularization is attached when configured.

    Returns:
      A `tf.keras.Model` built by `factory_3d.build_model`.
    """
    common_input_shape = self._get_feature_shape()
    input_specs = tf.keras.layers.InputSpec(shape=[None] + common_input_shape)
    logging.info('Build model input %r', common_input_shape)

    l2_weight_decay = self.task_config.losses.l2_weight_decay
    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    l2_regularizer = (tf.keras.regularizers.l2(
        l2_weight_decay / 2.0) if l2_weight_decay else None)

    model = factory_3d.build_model(
        self.task_config.model.model_type,
        input_specs=input_specs,
        model_config=self.task_config.model,
        num_classes=self._get_num_classes(),
        l2_regularizer=l2_regularizer)
    return model
  def initialize(self, model: tf.keras.Model):
    """Loads a pretrained checkpoint into `model`, if one is configured.

    Only 'all' and 'backbone' are supported for `init_checkpoint_modules`.

    Raises:
      ValueError: if `init_checkpoint_modules` is neither 'all' nor
        'backbone'.
    """
    if not self.task_config.init_checkpoint:
      return

    ckpt_dir_or_file = self.task_config.init_checkpoint
    # A directory means "use the latest checkpoint inside it".
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

    # Restoring checkpoint.
    if self.task_config.init_checkpoint_modules == 'all':
      ckpt = tf.train.Checkpoint(**model.checkpoint_items)
      status = ckpt.read(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()
    elif self.task_config.init_checkpoint_modules == 'backbone':
      ckpt = tf.train.Checkpoint(backbone=model.backbone)
      status = ckpt.read(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()
    else:
      raise ValueError(
          "Only 'all' or 'backbone' can be used to initialize the model.")

    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)
def
_get_dataset_fn
(
self
,
params
):
if
params
.
file_type
==
'tfrecord'
:
return
tf
.
data
.
TFRecordDataset
else
:
raise
ValueError
(
'Unknown input file type {!r}'
.
format
(
params
.
file_type
))
  def _get_decoder_fn(self, params):
    """Builds the example decode function for the given data config.

    Chooses the TFDS decoder when a TFDS dataset is configured, otherwise the
    TF Example decoder; optionally registers an audio feature on the decoder.

    Returns:
      The bound `decode` method of the constructed decoder.
    """
    if params.tfds_name:
      decoder = video_input.VideoTfdsDecoder(
          image_key=params.image_field_key, label_key=params.label_field_key)
    else:
      decoder = video_input.Decoder(
          image_key=params.image_field_key, label_key=params.label_field_key)
    if self.task_config.train_data.output_audio:
      assert self.task_config.train_data.audio_feature, 'audio feature is empty'
      decoder.add_feature(self.task_config.train_data.audio_feature,
                          tf.io.VarLenFeature(dtype=tf.float32))
    return decoder.decode
  def build_inputs(self,
                   params: exp_cfg.DataConfig,
                   input_context: Optional[tf.distribute.InputContext] = None):
    """Builds the video classification input pipeline.

    Args:
      params: the data config describing source, decoder and parser settings.
      input_context: optional `tf.distribute.InputContext` for sharding.

    Returns:
      A `tf.data.Dataset` of parsed and post-batch-processed examples.
    """
    parser = video_input.Parser(
        input_params=params,
        image_key=params.image_field_key,
        label_key=params.label_field_key)
    postprocess_fn = video_input.PostBatchProcessor(params)

    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=self._get_dataset_fn(params),
        decoder_fn=self._get_decoder_fn(params),
        parser_fn=parser.parse_fn(params.is_training),
        postprocess_fn=postprocess_fn)

    dataset = reader.read(input_context=input_context)

    return dataset
  def build_losses(self,
                   labels: Any,
                   model_outputs: Any,
                   aux_losses: Optional[Any] = None):
    """Builds the classification losses.

    Expects `model_outputs` to already be probabilities (sigmoid for
    multi-label, softmax otherwise) — callers apply the activation before
    invoking this.

    Args:
      labels: labels.
      model_outputs: output probabilities of the classifier.
      aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.

    Returns:
      A dict of loss tensors: 'class_loss', optionally 'entropy' and
      'reg_loss', plus the total loss under the key `self.loss`.
    """
    all_losses = {}
    losses_config = self.task_config.losses
    total_loss = None
    if self._is_multilabel():
      # Prediction entropy is logged as a diagnostic; 1e-8 guards log(0).
      entropy = -tf.reduce_mean(
          tf.reduce_sum(model_outputs * tf.math.log(model_outputs + 1e-8), -1))
      total_loss = tf.keras.losses.binary_crossentropy(
          labels, model_outputs, from_logits=False)
      all_losses.update({
          'class_loss': total_loss,
          'entropy': entropy,
      })
    else:
      if losses_config.one_hot:
        total_loss = tf.keras.losses.categorical_crossentropy(
            labels,
            model_outputs,
            from_logits=False,
            label_smoothing=losses_config.label_smoothing)
      else:
        total_loss = tf.keras.losses.sparse_categorical_crossentropy(
            labels, model_outputs, from_logits=False)

      total_loss = tf_utils.safe_mean(total_loss)
      all_losses.update({
          'class_loss': total_loss,
      })
    if aux_losses:
      all_losses.update({
          'reg_loss': aux_losses,
      })
      total_loss += tf.add_n(aux_losses)
    all_losses[self.loss] = total_loss

    return all_losses
  def build_metrics(self, training: bool = True):
    """Gets streaming metrics for training/validation.

    One-hot labels get categorical accuracy/top-k metrics, plus AUC/recall
    metrics in the multi-label case; sparse labels get the sparse variants.

    Args:
      training: unused here; the same metrics are built for both modes.

    Returns:
      A list of `tf.keras.metrics.Metric` objects.
    """
    if self.task_config.losses.one_hot:
      metrics = [
          tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
          tf.keras.metrics.TopKCategoricalAccuracy(k=1, name='top_1_accuracy'),
          tf.keras.metrics.TopKCategoricalAccuracy(k=5, name='top_5_accuracy')
      ]
      if self._is_multilabel():
        metrics.append(
            tf.keras.metrics.AUC(
                curve='ROC', multi_label=self._is_multilabel(),
                name='ROC-AUC'))
        metrics.append(
            tf.keras.metrics.RecallAtPrecision(
                0.95, name='RecallAtPrecision95'))
        metrics.append(
            tf.keras.metrics.AUC(
                curve='PR', multi_label=self._is_multilabel(), name='PR-AUC'))
        if self.task_config.metrics.use_per_class_recall:
          for i in range(self._get_num_classes()):
            metrics.append(
                tf.keras.metrics.Recall(class_id=i, name=f'recall-{i}'))
    else:
      metrics = [
          tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
          tf.keras.metrics.SparseTopKCategoricalAccuracy(
              k=1, name='top_1_accuracy'),
          tf.keras.metrics.SparseTopKCategoricalAccuracy(
              k=5, name='top_5_accuracy')
      ]
    return metrics
def process_metrics(self, metrics: List[Any], labels: Any, model_outputs: Any):
  """Feeds the current batch into every streaming metric.

  Called when using the custom training loop API.

  Args:
    metrics: a nested structure of metrics objects, as returned by
      `self.build_metrics`.
    labels: a tensor or a nested structure of tensors (ground truth).
    model_outputs: a tensor or a nested structure of tensors, e.g. the output
      of the keras model built by `self.build_model`.
  """
  for m in metrics:
    m.update_state(labels, model_outputs)
def train_step(self,
               inputs: Tuple[Any, Any],
               model: tf.keras.Model,
               optimizer: tf.keras.optimizers.Optimizer,
               metrics: Optional[List[Any]] = None):
  """Does forward and backward.

  Args:
    inputs: a (features, labels) tuple produced by the input pipeline;
      features is a dict containing at least an 'image' tensor.
    model: the model, forward pass definition.
    optimizer: the optimizer for this training step.
    metrics: a nested structure of metrics objects.

  Returns:
    A dictionary of logs: every loss from `build_losses` plus the current
    value of each metric.
  """
  features, labels = inputs
  input_partition_dims = self.task_config.train_input_partition_dims
  if input_partition_dims:
    # Spatially partition the input across logical TPU devices so a single
    # (large) example can be processed by several cores.
    strategy = tf.distribute.get_strategy()
    features['image'] = strategy.experimental_split_to_logical_devices(
        features['image'], input_partition_dims)

  num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
  with tf.GradientTape() as tape:
    outputs = model(features, training=True)
    # Casting output layer as float32 is necessary when mixed_precision is
    # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
    outputs = tf.nest.map_structure(
        lambda x: tf.cast(x, tf.float32), outputs)

    # Computes per-replica loss. Multilabel heads use independent sigmoids;
    # otherwise a softmax distribution over classes.
    if self._is_multilabel():
      outputs = tf.math.sigmoid(outputs)
    else:
      outputs = tf.math.softmax(outputs)
    all_losses = self.build_losses(
        model_outputs=outputs, labels=labels, aux_losses=model.losses)
    loss = all_losses[self.loss]
    # Scales loss as the default gradients allreduce performs sum inside the
    # optimizer.
    scaled_loss = loss / num_replicas

    # For mixed_precision policy, when LossScaleOptimizer is used, loss is
    # scaled for numerical stability.
    if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
      scaled_loss = optimizer.get_scaled_loss(scaled_loss)

  tvars = model.trainable_variables
  grads = tape.gradient(scaled_loss, tvars)
  # Scales back gradient before apply_gradients when LossScaleOptimizer is
  # used.
  if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
    grads = optimizer.get_unscaled_gradients(grads)
  optimizer.apply_gradients(list(zip(grads, tvars)))

  logs = all_losses
  if metrics:
    self.process_metrics(metrics, labels, outputs)
    logs.update({m.name: m.result() for m in metrics})
  elif model.compiled_metrics:
    # Fall back to Keras compiled metrics when no explicit metrics are given.
    self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
    logs.update({m.name: m.result() for m in model.metrics})
  return logs
def validation_step(self,
                    inputs: Tuple[Any, Any],
                    model: tf.keras.Model,
                    metrics: Optional[List[Any]] = None):
  """Validation step.

  Args:
    inputs: a (features, labels) tuple produced by the input pipeline;
      features is a dict containing at least an 'image' tensor.
    model: the keras.Model.
    metrics: a nested structure of metrics objects.

  Returns:
    A dictionary of logs: every loss from `build_losses` plus the current
    value of each metric.
  """
  features, labels = inputs
  input_partition_dims = self.task_config.eval_input_partition_dims
  if input_partition_dims:
    # Spatially partition the input across logical TPU devices, mirroring the
    # partitioning done in train_step.
    strategy = tf.distribute.get_strategy()
    features['image'] = strategy.experimental_split_to_logical_devices(
        features['image'], input_partition_dims)

  outputs = self.inference_step(features, model)
  # Cast to float32 so losses/metrics are computed in full precision even
  # under mixed_float16 / mixed_bfloat16 policies.
  outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
  logs = self.build_losses(
      model_outputs=outputs, labels=labels, aux_losses=model.losses)

  if metrics:
    self.process_metrics(metrics, labels, outputs)
    logs.update({m.name: m.result() for m in metrics})
  elif model.compiled_metrics:
    # Fall back to Keras compiled metrics when no explicit metrics are given.
    self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
    logs.update({m.name: m.result() for m in model.metrics})
  return logs
def inference_step(self, features: tf.Tensor, model: tf.keras.Model):
  """Performs the forward step.

  Runs the model in inference mode, converts logits to probabilities, and,
  when multiple test-time views are configured, averages the probabilities
  over the views of each example.
  """
  logits = model(features, training=False)
  # Multilabel heads produce independent per-class probabilities via sigmoid;
  # single-label heads produce a softmax distribution over classes.
  if self._is_multilabel():
    probs = tf.math.sigmoid(logits)
  else:
    probs = tf.math.softmax(logits)

  num_test_views = self._get_num_test_views()
  if num_test_views > 1:
    # Averaging output probabilities across multiple views: regroup the flat
    # batch as (example, view, class) and reduce over the view axis.
    probs = tf.reshape(probs, [-1, num_test_views, probs.shape[-1]])
    probs = tf.reduce_mean(probs, axis=1)
  return probs
official/vision/beta/train.py
deleted
100644 → 0
View file @
f47405b5
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorFlow Model Garden Vision training driver."""
from
absl
import
app
from
absl
import
flags
import
gin
# pylint: disable=unused-import
from
official.common
import
registry_imports
# pylint: enable=unused-import
from
official.common
import
distribute_utils
from
official.common
import
flags
as
tfm_flags
from
official.core
import
task_factory
from
official.core
import
train_lib
from
official.core
import
train_utils
from
official.modeling
import
performance
FLAGS
=
flags
.
FLAGS
def main(_):
  """Parses configuration, builds the task, and launches the experiment."""
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  params = train_utils.parse_configuration(FLAGS)
  model_dir = FLAGS.model_dir
  if 'train' in FLAGS.mode:
    # Pure eval modes do not output yaml files. Otherwise continuous eval job
    # may race against the train job for writing the same file.
    train_utils.serialize_config(params, model_dir)

  runtime = params.runtime
  # Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can have significant impact on model speeds by utilizing float16 in case
  # of GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only
  # when dtype is float16.
  if runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(runtime.mixed_precision_dtype)

  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=runtime.distribution_strategy,
      all_reduce_alg=runtime.all_reduce_alg,
      num_gpus=runtime.num_gpus,
      tpu_address=runtime.tpu)
  with strategy.scope():
    task = task_factory.get_task(params.task, logging_dir=model_dir)

  train_lib.run_experiment(
      distribution_strategy=strategy,
      task=task,
      mode=FLAGS.mode,
      params=params,
      model_dir=model_dir)
  train_utils.save_gin_config(FLAGS.mode, model_dir)
if __name__ == '__main__':
  # Register the standard TF Model Garden flags, require the ones every run
  # needs, then hand control to absl.
  tfm_flags.define_flags()
  flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
  app.run(main)
official/vision/beta/train_spatial_partitioning.py
deleted
100644 → 0
View file @
f47405b5
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorFlow Model Garden Vision training driver with spatial partitioning."""
from
typing
import
Sequence
from
absl
import
app
from
absl
import
flags
import
gin
import
numpy
as
np
import
tensorflow
as
tf
from
official.common
import
registry_imports
# pylint: disable=unused-import
from
official.common
import
distribute_utils
from
official.common
import
flags
as
tfm_flags
from
official.core
import
task_factory
from
official.core
import
train_lib
from
official.core
import
train_utils
from
official.modeling
import
performance
FLAGS
=
flags
.
FLAGS
def get_computation_shape_for_model_parallelism(
    input_partition_dims: Sequence[int]) -> Sequence[int]:
  """Returns computation shape to be used for TPUStrategy spatial partition.

  Args:
    input_partition_dims: The number of partitions along each dimension.

  Returns:
    A list of integers specifying the computation shape.

  Raises:
    ValueError: If the number of logical devices is not supported.
  """
  # The total logical-device count fully determines the 4-d computation
  # shape, so a lookup table replaces the branch chain.
  shapes_by_device_count = {
      1: [1, 1, 1, 1],
      2: [1, 1, 1, 2],
      4: [1, 2, 1, 2],
      8: [2, 2, 1, 2],
      16: [4, 2, 1, 2],
  }
  num_logical_devices = np.prod(input_partition_dims)
  shape = shapes_by_device_count.get(num_logical_devices)
  if shape is None:
    raise ValueError(
        'The number of logical devices %d is not supported. Supported numbers '
        'are 1, 2, 4, 8, 16' % num_logical_devices)
  # Return a fresh list so callers cannot mutate the shared table entry.
  return list(shape)
def create_distribution_strategy(distribution_strategy,
                                 tpu_address,
                                 input_partition_dims=None,
                                 num_gpus=None):
  """Creates distribution strategy to use for computation.

  Args:
    distribution_strategy: name of the strategy, e.g. 'tpu'; must be 'tpu'
      when `input_partition_dims` is set.
    tpu_address: address of the TPU system ('' or 'local' for a local TPU).
    num_gpus: number of GPUs, forwarded to distribute_utils for non-spatial
      strategies.
    input_partition_dims: number of partitions along each input dimension;
      when set, a TPUStrategy with a custom device assignment is built.

  Returns:
    A tf.distribute strategy instance.

  Raises:
    ValueError: if `input_partition_dims` is given with a non-TPU strategy.
  """
  if input_partition_dims is not None:
    if distribution_strategy != 'tpu':
      raise ValueError('Spatial partitioning is only supported '
                       'for TPUStrategy.')

    # When `input_partition_dims` is specified create custom TPUStrategy
    # instance with computation shape for model parallelism.
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu=tpu_address)
    if tpu_address not in ('', 'local'):
      # Remote TPU: connect to the cluster before initializing the system.
      tf.config.experimental_connect_to_cluster(resolver)
    topology = tf.tpu.experimental.initialize_tpu_system(resolver)
    # Each replica spans prod(input_partition_dims) logical cores, so the
    # replica count shrinks accordingly.
    num_replicas = resolver.get_tpu_system_metadata().num_cores // np.prod(
        input_partition_dims)
    device_assignment = tf.tpu.experimental.DeviceAssignment.build(
        topology,
        num_replicas=num_replicas,
        computation_shape=input_partition_dims)
    return tf.distribute.TPUStrategy(
        resolver, experimental_device_assignment=device_assignment)

  # No spatial partitioning requested: defer to the common helper.
  return distribute_utils.get_distribution_strategy(
      distribution_strategy=distribution_strategy,
      tpu_address=tpu_address,
      num_gpus=num_gpus)
def main(_):
  """Parses configs, resolves spatial partitioning, and runs the experiment."""
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  params = train_utils.parse_configuration(FLAGS)
  model_dir = FLAGS.model_dir
  if 'train' in FLAGS.mode:
    # Pure eval modes do not output yaml files. Otherwise continuous eval job
    # may race against the train job for writing the same file.
    train_utils.serialize_config(params, model_dir)

  # Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can have significant impact on model speeds by utilizing float16 in case
  # of GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only
  # when dtype is float16.
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)

  # Derive the TPU computation shape from the partition dims of whichever
  # phase(s) this mode runs.
  input_partition_dims = None
  if FLAGS.mode == 'train_and_eval':
    if np.prod(params.task.train_input_partition_dims) != np.prod(
        params.task.eval_input_partition_dims):
      # Fix: the two literals previously concatenated without a separating
      # space, producing 'can not bepartitioned'.
      raise ValueError('Train and eval input partition dims can not be '
                       'partitioned on the same node')
    else:
      input_partition_dims = get_computation_shape_for_model_parallelism(
          params.task.train_input_partition_dims)
  elif FLAGS.mode == 'train':
    if params.task.train_input_partition_dims:
      input_partition_dims = get_computation_shape_for_model_parallelism(
          params.task.train_input_partition_dims)
  elif FLAGS.mode == 'eval' or FLAGS.mode == 'continuous_eval':
    if params.task.eval_input_partition_dims:
      input_partition_dims = get_computation_shape_for_model_parallelism(
          params.task.eval_input_partition_dims)

  distribution_strategy = create_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      num_gpus=params.runtime.num_gpus,
      input_partition_dims=input_partition_dims,
      tpu_address=params.runtime.tpu)
  with distribution_strategy.scope():
    task = task_factory.get_task(params.task, logging_dir=model_dir)
  train_lib.run_experiment(
      distribution_strategy=distribution_strategy,
      task=task,
      mode=FLAGS.mode,
      params=params,
      model_dir=model_dir)
if __name__ == '__main__':
  # Register the standard TF Model Garden flags and hand control to absl.
  # NOTE(review): unlike train.py, no flags are marked required here — confirm
  # whether that is intentional.
  tfm_flags.define_flags()
  app.run(main)
Prev
1
…
8
9
10
11
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment