Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
b0ccdb11
Commit
b0ccdb11
authored
Sep 28, 2020
by
Shixin Luo
Browse files
resolve conflict with master
parents
e61588cd
1611a8c5
Changes
210
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
737 additions
and
582 deletions
+737
-582
official/vision/beta/ops/nms_test.py
official/vision/beta/ops/nms_test.py
+0
-104
official/vision/beta/ops/spatial_transform_ops.py
official/vision/beta/ops/spatial_transform_ops.py
+6
-6
official/vision/beta/ops/spatial_transform_ops_test.py
official/vision/beta/ops/spatial_transform_ops_test.py
+0
-287
official/vision/beta/serving/export_base.py
official/vision/beta/serving/export_base.py
+94
-0
official/vision/beta/serving/export_saved_model.py
official/vision/beta/serving/export_saved_model.py
+181
-0
official/vision/beta/serving/image_classification.py
official/vision/beta/serving/image_classification.py
+83
-0
official/vision/beta/serving/image_classification_test.py
official/vision/beta/serving/image_classification_test.py
+126
-0
official/vision/beta/tasks/__init__.py
official/vision/beta/tasks/__init__.py
+1
-0
official/vision/beta/tasks/maskrcnn_test.py
official/vision/beta/tasks/maskrcnn_test.py
+0
-70
official/vision/beta/tasks/retinanet.py
official/vision/beta/tasks/retinanet.py
+5
-5
official/vision/beta/tasks/retinanet_test.py
official/vision/beta/tasks/retinanet_test.py
+0
-65
official/vision/beta/tasks/video_classification.py
official/vision/beta/tasks/video_classification.py
+201
-0
official/vision/beta/train.py
official/vision/beta/train.py
+2
-2
official/vision/detection/dataloader/anchor.py
official/vision/detection/dataloader/anchor.py
+2
-2
official/vision/detection/main.py
official/vision/detection/main.py
+5
-14
official/vision/detection/utils/object_detection/target_assigner.py
...ision/detection/utils/object_detection/target_assigner.py
+2
-2
official/vision/image_classification/callbacks.py
official/vision/image_classification/callbacks.py
+16
-10
official/vision/image_classification/classifier_trainer.py
official/vision/image_classification/classifier_trainer.py
+7
-7
official/vision/image_classification/mnist_main.py
official/vision/image_classification/mnist_main.py
+3
-4
official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py
...n/image_classification/resnet/resnet_ctl_imagenet_main.py
+3
-4
No files found.
official/vision/beta/ops/nms_test.py
deleted
100644 → 0
View file @
e61588cd
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for nms.py."""
# Import libraries
import
numpy
as
np
import
tensorflow
as
tf
from
official.vision.beta.ops
import
nms
class
SortedNonMaxSuppressionTest
(
tf
.
test
.
TestCase
):
def
setUp
(
self
):
super
(
SortedNonMaxSuppressionTest
,
self
).
setUp
()
self
.
boxes_data
=
[[[
0
,
0
,
1
,
1
],
[
0
,
0.2
,
1
,
1.2
],
[
0
,
0.4
,
1
,
1.4
],
[
0
,
0.6
,
1
,
1.6
],
[
0
,
0.8
,
1
,
1.8
],
[
0
,
2
,
1
,
2
]],
[[
0
,
2
,
1
,
2
],
[
0
,
0.8
,
1
,
1.8
],
[
0
,
0.6
,
1
,
1.6
],
[
0
,
0.4
,
1
,
1.4
],
[
0
,
0.2
,
1
,
1.2
],
[
0
,
0
,
1
,
1
]]]
self
.
scores_data
=
[[
0.9
,
0.7
,
0.6
,
0.5
,
0.4
,
0.3
],
[
0.8
,
0.7
,
0.6
,
0.5
,
0.4
,
0.3
]]
self
.
max_output_size
=
6
self
.
iou_threshold
=
0.5
def
testSortedNonMaxSuppressionOnTPU
(
self
):
boxes_np
=
np
.
array
(
self
.
boxes_data
,
dtype
=
np
.
float32
)
scores_np
=
np
.
array
(
self
.
scores_data
,
dtype
=
np
.
float32
)
iou_threshold_np
=
np
.
array
(
self
.
iou_threshold
,
dtype
=
np
.
float32
)
boxes
=
tf
.
constant
(
boxes_np
)
scores
=
tf
.
constant
(
scores_np
)
iou_threshold
=
tf
.
constant
(
iou_threshold_np
)
# Runs on TPU.
strategy
=
tf
.
distribute
.
experimental
.
TPUStrategy
()
with
strategy
.
scope
():
scores_tpu
,
boxes_tpu
=
nms
.
sorted_non_max_suppression_padded
(
boxes
=
boxes
,
scores
=
scores
,
max_output_size
=
self
.
max_output_size
,
iou_threshold
=
iou_threshold
)
self
.
assertEqual
(
boxes_tpu
.
numpy
().
shape
,
(
2
,
self
.
max_output_size
,
4
))
self
.
assertAllClose
(
scores_tpu
.
numpy
(),
[[
0.9
,
0.6
,
0.4
,
0.3
,
0.
,
0.
],
[
0.8
,
0.7
,
0.5
,
0.3
,
0.
,
0.
]])
def
testSortedNonMaxSuppressionOnCPU
(
self
):
boxes_np
=
np
.
array
(
self
.
boxes_data
,
dtype
=
np
.
float32
)
scores_np
=
np
.
array
(
self
.
scores_data
,
dtype
=
np
.
float32
)
iou_threshold_np
=
np
.
array
(
self
.
iou_threshold
,
dtype
=
np
.
float32
)
boxes
=
tf
.
constant
(
boxes_np
)
scores
=
tf
.
constant
(
scores_np
)
iou_threshold
=
tf
.
constant
(
iou_threshold_np
)
# Runs on CPU.
scores_cpu
,
boxes_cpu
=
nms
.
sorted_non_max_suppression_padded
(
boxes
=
boxes
,
scores
=
scores
,
max_output_size
=
self
.
max_output_size
,
iou_threshold
=
iou_threshold
)
self
.
assertEqual
(
boxes_cpu
.
numpy
().
shape
,
(
2
,
self
.
max_output_size
,
4
))
self
.
assertAllClose
(
scores_cpu
.
numpy
(),
[[
0.9
,
0.6
,
0.4
,
0.3
,
0.
,
0.
],
[
0.8
,
0.7
,
0.5
,
0.3
,
0.
,
0.
]])
def
testSortedNonMaxSuppressionOnTPUSpeed
(
self
):
boxes_np
=
np
.
random
.
rand
(
2
,
12000
,
4
).
astype
(
np
.
float32
)
scores_np
=
np
.
random
.
rand
(
2
,
12000
).
astype
(
np
.
float32
)
iou_threshold_np
=
np
.
array
(
0.7
,
dtype
=
np
.
float32
)
boxes
=
tf
.
constant
(
boxes_np
)
scores
=
tf
.
constant
(
scores_np
)
iou_threshold
=
tf
.
constant
(
iou_threshold_np
)
# Runs on TPU.
strategy
=
tf
.
distribute
.
experimental
.
TPUStrategy
()
with
strategy
.
scope
():
scores_tpu
,
boxes_tpu
=
nms
.
sorted_non_max_suppression_padded
(
boxes
=
boxes
,
scores
=
scores
,
max_output_size
=
2000
,
iou_threshold
=
iou_threshold
)
self
.
assertEqual
(
scores_tpu
.
numpy
().
shape
,
(
2
,
2000
))
self
.
assertEqual
(
boxes_tpu
.
numpy
().
shape
,
(
2
,
2000
,
4
))
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/vision/beta/ops/spatial_transform_ops.py
View file @
b0ccdb11
...
@@ -159,12 +159,12 @@ def multilevel_crop_and_resize(features,
...
@@ -159,12 +159,12 @@ def multilevel_crop_and_resize(features,
with
tf
.
name_scope
(
'multilevel_crop_and_resize'
):
with
tf
.
name_scope
(
'multilevel_crop_and_resize'
):
levels
=
list
(
features
.
keys
())
levels
=
list
(
features
.
keys
())
min_level
=
min
(
levels
)
min_level
=
int
(
min
(
levels
)
)
max_level
=
max
(
levels
)
max_level
=
int
(
max
(
levels
)
)
batch_size
,
max_feature_height
,
max_feature_width
,
num_filters
=
(
batch_size
,
max_feature_height
,
max_feature_width
,
num_filters
=
(
features
[
min_level
].
get_shape
().
as_list
())
features
[
str
(
min_level
)
].
get_shape
().
as_list
())
if
batch_size
is
None
:
if
batch_size
is
None
:
batch_size
=
tf
.
shape
(
features
[
min_level
])[
0
]
batch_size
=
tf
.
shape
(
features
[
str
(
min_level
)
])[
0
]
_
,
num_boxes
,
_
=
boxes
.
get_shape
().
as_list
()
_
,
num_boxes
,
_
=
boxes
.
get_shape
().
as_list
()
# Stack feature pyramid into a features_all of shape
# Stack feature pyramid into a features_all of shape
...
@@ -173,13 +173,13 @@ def multilevel_crop_and_resize(features,
...
@@ -173,13 +173,13 @@ def multilevel_crop_and_resize(features,
feature_heights
=
[]
feature_heights
=
[]
feature_widths
=
[]
feature_widths
=
[]
for
level
in
range
(
min_level
,
max_level
+
1
):
for
level
in
range
(
min_level
,
max_level
+
1
):
shape
=
features
[
level
].
get_shape
().
as_list
()
shape
=
features
[
str
(
level
)
].
get_shape
().
as_list
()
feature_heights
.
append
(
shape
[
1
])
feature_heights
.
append
(
shape
[
1
])
feature_widths
.
append
(
shape
[
2
])
feature_widths
.
append
(
shape
[
2
])
# Concat tensor of [batch_size, height_l * width_l, num_filters] for each
# Concat tensor of [batch_size, height_l * width_l, num_filters] for each
# levels.
# levels.
features_all
.
append
(
features_all
.
append
(
tf
.
reshape
(
features
[
level
],
[
batch_size
,
-
1
,
num_filters
]))
tf
.
reshape
(
features
[
str
(
level
)
],
[
batch_size
,
-
1
,
num_filters
]))
features_r2
=
tf
.
reshape
(
tf
.
concat
(
features_all
,
1
),
[
-
1
,
num_filters
])
features_r2
=
tf
.
reshape
(
tf
.
concat
(
features_all
,
1
),
[
-
1
,
num_filters
])
# Calculate height_l * width_l for each level.
# Calculate height_l * width_l for each level.
...
...
official/vision/beta/ops/spatial_transform_ops_test.py
deleted
100644 → 0
View file @
e61588cd
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for spatial_transform_ops.py."""
# Import libraries
from
absl.testing
import
parameterized
import
numpy
as
np
import
tensorflow
as
tf
from
official.vision.beta.ops
import
spatial_transform_ops
class
MultiLevelCropAndResizeTest
(
tf
.
test
.
TestCase
):
def
test_multilevel_crop_and_resize_square
(
self
):
"""Example test case.
Input =
[
[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11],
[12, 13, 14, 15],
]
output_size = 2x2
box =
[
[[0, 0, 2, 2]]
]
Gathered data =
[
[0, 1, 1, 2],
[4, 5, 5, 6],
[4, 5, 5, 6],
[8, 9, 9, 10],
]
Interpolation kernel =
[
[1, 1, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1],
]
Output =
[
[2.5, 3.5],
[6.5, 7.5]
]
"""
input_size
=
4
min_level
=
0
max_level
=
0
batch_size
=
1
output_size
=
2
num_filters
=
1
features
=
{}
for
level
in
range
(
min_level
,
max_level
+
1
):
feat_size
=
int
(
input_size
/
2
**
level
)
features
[
level
]
=
tf
.
range
(
batch_size
*
feat_size
*
feat_size
*
num_filters
,
dtype
=
tf
.
float32
)
features
[
level
]
=
tf
.
reshape
(
features
[
level
],
[
batch_size
,
feat_size
,
feat_size
,
num_filters
])
boxes
=
tf
.
constant
([
[[
0
,
0
,
2
,
2
]],
],
dtype
=
tf
.
float32
)
tf_roi_features
=
spatial_transform_ops
.
multilevel_crop_and_resize
(
features
,
boxes
,
output_size
)
roi_features
=
tf_roi_features
.
numpy
()
self
.
assertAllClose
(
roi_features
,
np
.
array
([[
2.5
,
3.5
],
[
6.5
,
7.5
]]).
reshape
([
batch_size
,
1
,
output_size
,
output_size
,
1
]))
def
test_multilevel_crop_and_resize_rectangle
(
self
):
"""Example test case.
Input =
[
[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11],
[12, 13, 14, 15],
]
output_size = 2x2
box =
[
[[0, 0, 2, 3]]
]
Box vertices =
[
[[0.5, 0.75], [0.5, 2.25]],
[[1.5, 0.75], [1.5, 2.25]],
]
Gathered data =
[
[0, 1, 2, 3],
[4, 5, 6, 7],
[4, 5, 6, 7],
[8, 9, 10, 11],
]
Interpolation kernel =
[
[0.5 1.5 1.5 0.5],
[0.5 1.5 1.5 0.5],
[0.5 1.5 1.5 0.5],
[0.5 1.5 1.5 0.5],
]
Output =
[
[2.75, 4.25],
[6.75, 8.25]
]
"""
input_size
=
4
min_level
=
0
max_level
=
0
batch_size
=
1
output_size
=
2
num_filters
=
1
features
=
{}
for
level
in
range
(
min_level
,
max_level
+
1
):
feat_size
=
int
(
input_size
/
2
**
level
)
features
[
level
]
=
tf
.
range
(
batch_size
*
feat_size
*
feat_size
*
num_filters
,
dtype
=
tf
.
float32
)
features
[
level
]
=
tf
.
reshape
(
features
[
level
],
[
batch_size
,
feat_size
,
feat_size
,
num_filters
])
boxes
=
tf
.
constant
([
[[
0
,
0
,
2
,
3
]],
],
dtype
=
tf
.
float32
)
tf_roi_features
=
spatial_transform_ops
.
multilevel_crop_and_resize
(
features
,
boxes
,
output_size
)
roi_features
=
tf_roi_features
.
numpy
()
self
.
assertAllClose
(
roi_features
,
np
.
array
([[
2.75
,
4.25
],
[
6.75
,
8.25
]]).
reshape
([
batch_size
,
1
,
output_size
,
output_size
,
1
]))
def
test_multilevel_crop_and_resize_two_boxes
(
self
):
"""Test two boxes."""
input_size
=
4
min_level
=
0
max_level
=
0
batch_size
=
1
output_size
=
2
num_filters
=
1
features
=
{}
for
level
in
range
(
min_level
,
max_level
+
1
):
feat_size
=
int
(
input_size
/
2
**
level
)
features
[
level
]
=
tf
.
range
(
batch_size
*
feat_size
*
feat_size
*
num_filters
,
dtype
=
tf
.
float32
)
features
[
level
]
=
tf
.
reshape
(
features
[
level
],
[
batch_size
,
feat_size
,
feat_size
,
num_filters
])
boxes
=
tf
.
constant
([
[[
0
,
0
,
2
,
2
],
[
0
,
0
,
2
,
3
]],
],
dtype
=
tf
.
float32
)
tf_roi_features
=
spatial_transform_ops
.
multilevel_crop_and_resize
(
features
,
boxes
,
output_size
)
roi_features
=
tf_roi_features
.
numpy
()
self
.
assertAllClose
(
roi_features
,
np
.
array
([[[
2.5
,
3.5
],
[
6.5
,
7.5
]],
[[
2.75
,
4.25
],
[
6.75
,
8.25
]]
]).
reshape
([
batch_size
,
2
,
output_size
,
output_size
,
1
]))
def
test_multilevel_crop_and_resize_feature_level_assignment
(
self
):
"""Test feature level assignment."""
input_size
=
640
min_level
=
2
max_level
=
5
batch_size
=
1
output_size
=
2
num_filters
=
1
features
=
{}
for
level
in
range
(
min_level
,
max_level
+
1
):
feat_size
=
int
(
input_size
/
2
**
level
)
features
[
level
]
=
float
(
level
)
*
tf
.
ones
(
[
batch_size
,
feat_size
,
feat_size
,
num_filters
],
dtype
=
tf
.
float32
)
boxes
=
tf
.
constant
(
[
[
[
0
,
0
,
111
,
111
],
# Level 2.
[
0
,
0
,
113
,
113
],
# Level 3.
[
0
,
0
,
223
,
223
],
# Level 3.
[
0
,
0
,
225
,
225
],
# Level 4.
[
0
,
0
,
449
,
449
]
],
# Level 5.
],
dtype
=
tf
.
float32
)
tf_roi_features
=
spatial_transform_ops
.
multilevel_crop_and_resize
(
features
,
boxes
,
output_size
)
roi_features
=
tf_roi_features
.
numpy
()
self
.
assertAllClose
(
roi_features
[
0
,
0
],
2
*
np
.
ones
((
2
,
2
,
1
)))
self
.
assertAllClose
(
roi_features
[
0
,
1
],
3
*
np
.
ones
((
2
,
2
,
1
)))
self
.
assertAllClose
(
roi_features
[
0
,
2
],
3
*
np
.
ones
((
2
,
2
,
1
)))
self
.
assertAllClose
(
roi_features
[
0
,
3
],
4
*
np
.
ones
((
2
,
2
,
1
)))
self
.
assertAllClose
(
roi_features
[
0
,
4
],
5
*
np
.
ones
((
2
,
2
,
1
)))
def
test_multilevel_crop_and_resize_large_input
(
self
):
"""Test 512 boxes on TPU."""
input_size
=
1408
min_level
=
2
max_level
=
6
batch_size
=
2
num_boxes
=
512
num_filters
=
256
output_size
=
7
strategy
=
tf
.
distribute
.
experimental
.
TPUStrategy
()
with
strategy
.
scope
():
features
=
{}
for
level
in
range
(
min_level
,
max_level
+
1
):
feat_size
=
int
(
input_size
/
2
**
level
)
features
[
level
]
=
tf
.
constant
(
np
.
reshape
(
np
.
arange
(
batch_size
*
feat_size
*
feat_size
*
num_filters
,
dtype
=
np
.
float32
),
[
batch_size
,
feat_size
,
feat_size
,
num_filters
]),
dtype
=
tf
.
bfloat16
)
boxes
=
np
.
array
([
[[
0
,
0
,
256
,
256
]]
*
num_boxes
,
],
dtype
=
np
.
float32
)
boxes
=
np
.
tile
(
boxes
,
[
batch_size
,
1
,
1
])
tf_boxes
=
tf
.
constant
(
boxes
,
dtype
=
tf
.
float32
)
tf_roi_features
=
spatial_transform_ops
.
multilevel_crop_and_resize
(
features
,
tf_boxes
)
roi_features
=
tf_roi_features
.
numpy
()
self
.
assertEqual
(
roi_features
.
shape
,
(
batch_size
,
num_boxes
,
output_size
,
output_size
,
num_filters
))
class
CropMaskInTargetBoxTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
@
parameterized
.
parameters
(
(
False
),
(
True
),
)
def
test_crop_mask_in_target_box
(
self
,
use_einsum
):
batch_size
=
1
num_masks
=
2
height
=
2
width
=
2
output_size
=
2
strategy
=
tf
.
distribute
.
experimental
.
TPUStrategy
()
with
strategy
.
scope
():
masks
=
tf
.
ones
([
batch_size
,
num_masks
,
height
,
width
])
boxes
=
tf
.
constant
(
[[
0.
,
0.
,
1.
,
1.
],
[
0.
,
0.
,
1.
,
1.
]])
target_boxes
=
tf
.
constant
(
[[
0.
,
0.
,
1.
,
1.
],
[
-
1.
,
-
1.
,
1.
,
1.
]])
expected_outputs
=
np
.
array
([
[[[
1.
,
1.
],
[
1.
,
1.
]],
[[
0.
,
0.
],
[
0.
,
1.
]]]])
boxes
=
tf
.
reshape
(
boxes
,
[
batch_size
,
num_masks
,
4
])
target_boxes
=
tf
.
reshape
(
target_boxes
,
[
batch_size
,
num_masks
,
4
])
tf_cropped_masks
=
spatial_transform_ops
.
crop_mask_in_target_box
(
masks
,
boxes
,
target_boxes
,
output_size
,
use_einsum
=
use_einsum
)
cropped_masks
=
tf_cropped_masks
.
numpy
()
self
.
assertAllEqual
(
cropped_masks
,
expected_outputs
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/vision/beta/serving/export_base.py
0 → 100644
View file @
b0ccdb11
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for model export."""
import
abc
import
tensorflow
as
tf
def
_decode_image
(
encoded_image_bytes
):
image_tensor
=
tf
.
image
.
decode_image
(
encoded_image_bytes
,
channels
=
3
)
image_tensor
.
set_shape
((
None
,
None
,
3
))
return
image_tensor
def
_decode_tf_example
(
tf_example_string_tensor
):
keys_to_features
=
{
'image/encoded'
:
tf
.
io
.
FixedLenFeature
((),
tf
.
string
)}
parsed_tensors
=
tf
.
io
.
parse_single_example
(
serialized
=
tf_example_string_tensor
,
features
=
keys_to_features
)
image_tensor
=
_decode_image
(
parsed_tensors
[
'image/encoded'
])
return
image_tensor
class
ExportModule
(
tf
.
Module
,
metaclass
=
abc
.
ABCMeta
):
"""Base Export Module."""
def
__init__
(
self
,
params
,
batch_size
,
input_image_size
,
model
=
None
):
"""Initializes a module for export.
Args:
params: Experiment params.
batch_size: Int or None.
input_image_size: List or Tuple of height, width of the input image.
model: A tf.keras.Model instance to be exported.
"""
super
(
ExportModule
,
self
).
__init__
()
self
.
_params
=
params
self
.
_batch_size
=
batch_size
self
.
_input_image_size
=
input_image_size
self
.
_model
=
model
@
abc
.
abstractmethod
def
build_model
(
self
):
"""Builds model and sets self._model."""
@
abc
.
abstractmethod
def
_run_inference_on_image_tensors
(
self
,
images
):
"""Runs inference on images."""
@
tf
.
function
def
inference_from_image_tensors
(
self
,
input_tensor
):
return
dict
(
outputs
=
self
.
_run_inference_on_image_tensors
(
input_tensor
))
@
tf
.
function
def
inference_from_image_bytes
(
self
,
input_tensor
):
with
tf
.
device
(
'cpu:0'
):
images
=
tf
.
nest
.
map_structure
(
tf
.
identity
,
tf
.
map_fn
(
_decode_image
,
elems
=
input_tensor
,
fn_output_signature
=
tf
.
TensorSpec
(
shape
=
self
.
_input_image_size
+
[
3
],
dtype
=
tf
.
uint8
),
parallel_iterations
=
32
))
images
=
tf
.
stack
(
images
)
return
dict
(
outputs
=
self
.
_run_inference_on_image_tensors
(
images
))
@
tf
.
function
def
inference_from_tf_example
(
self
,
input_tensor
):
with
tf
.
device
(
'cpu:0'
):
images
=
tf
.
nest
.
map_structure
(
tf
.
identity
,
tf
.
map_fn
(
_decode_tf_example
,
elems
=
input_tensor
,
fn_output_signature
=
tf
.
TensorSpec
(
shape
=
self
.
_input_image_size
+
[
3
],
dtype
=
tf
.
uint8
),
dtype
=
tf
.
uint8
,
parallel_iterations
=
32
))
images
=
tf
.
stack
(
images
)
return
dict
(
outputs
=
self
.
_run_inference_on_image_tensors
(
images
))
official/vision/beta/serving/export_saved_model.py
0 → 100644
View file @
b0ccdb11
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r
"""Vision models export binary for serving/inference.
To export a trained checkpoint in saved_model format (shell script):
EXPERIMENT_TYPE = XX
CHECKPOINT_PATH = XX
EXPORT_DIR_PATH = XX
export_saved_model --experiment=${EXPERIMENT_TYPE} \
--export_dir=${EXPORT_DIR_PATH}/ \
--checkpoint_path=${CHECKPOINT_PATH} \
--batch_size=2 \
--input_image_size=224,224
To serve (python):
export_dir_path = XX
input_type = XX
input_images = XX
imported = tf.saved_model.load(export_dir_path)
model_fn = .signatures['serving_default']
output = model_fn(input_images)
"""
import
os
from
absl
import
app
from
absl
import
flags
import
tensorflow.compat.v2
as
tf
from
official.common
import
registry_imports
# pylint: disable=unused-import
from
official.core
import
exp_factory
from
official.core
import
train_utils
from
official.modeling
import
hyperparams
from
official.vision.beta
import
configs
from
official.vision.beta.serving
import
image_classification
FLAGS
=
flags
.
FLAGS
flags
.
DEFINE_string
(
'experiment'
,
None
,
'experiment type, e.g. retinanet_resnetfpn_coco'
)
flags
.
DEFINE_string
(
'export_dir'
,
None
,
'The export directory.'
)
flags
.
DEFINE_string
(
'checkpoint_path'
,
None
,
'Checkpoint path.'
)
flags
.
DEFINE_multi_string
(
'config_file'
,
default
=
None
,
help
=
'YAML/JSON files which specifies overrides. The override order '
'follows the order of args. Note that each file '
'can be used as an override template to override the default parameters '
'specified in Python. If the same parameter is specified in both '
'`--config_file` and `--params_override`, `config_file` will be used '
'first, followed by params_override.'
)
flags
.
DEFINE_string
(
'params_override'
,
''
,
'The JSON/YAML file or string which specifies the parameter to be overriden'
' on top of `config_file` template.'
)
flags
.
DEFINE_integer
(
'batch_size'
,
None
,
'The batch size.'
)
flags
.
DEFINE_string
(
'input_type'
,
'image_tensor'
,
'One of `image_tensor`, `image_bytes`, `tf_example`.'
)
flags
.
DEFINE_string
(
'input_image_size'
,
'224,224'
,
'The comma-separated string of two integers representing the height,width '
'of the input to the model.'
)
def
export_inference_graph
(
input_type
,
batch_size
,
input_image_size
,
params
,
checkpoint_path
,
export_dir
):
"""Exports inference graph for the model specified in the exp config.
Saved model is stored at export_dir/saved_model, checkpoint is saved
at export_dir/checkpoint, and params is saved at export_dir/params.yaml.
Args:
input_type: One of `image_tensor`, `image_bytes`, `tf_example`.
batch_size: 'int', or None.
input_image_size: List or Tuple of height and width.
params: Experiment params.
checkpoint_path: Trained checkpoint path or directory.
export_dir: Export directory path.
"""
output_checkpoint_directory
=
os
.
path
.
join
(
export_dir
,
'checkpoint'
)
output_saved_model_directory
=
os
.
path
.
join
(
export_dir
,
'saved_model'
)
if
isinstance
(
params
.
task
,
configs
.
image_classification
.
ImageClassificationTask
):
export_module
=
image_classification
.
ClassificationModule
(
params
=
params
,
batch_size
=
batch_size
,
input_image_size
=
input_image_size
)
else
:
raise
ValueError
(
'Export module not implemented for {} task.'
.
format
(
type
(
params
.
task
)))
model
=
export_module
.
build_model
()
ckpt
=
tf
.
train
.
Checkpoint
(
model
=
model
)
ckpt_dir_or_file
=
checkpoint_path
if
tf
.
io
.
gfile
.
isdir
(
ckpt_dir_or_file
):
ckpt_dir_or_file
=
tf
.
train
.
latest_checkpoint
(
ckpt_dir_or_file
)
status
=
ckpt
.
restore
(
ckpt_dir_or_file
).
expect_partial
()
if
input_type
==
'image_tensor'
:
input_signature
=
tf
.
TensorSpec
(
shape
=
[
batch_size
,
input_image_size
[
0
],
input_image_size
[
1
],
3
],
dtype
=
tf
.
uint8
)
signatures
=
{
'serving_default'
:
export_module
.
inference_from_image
.
get_concrete_function
(
input_signature
)
}
elif
input_type
==
'image_bytes'
:
input_signature
=
tf
.
TensorSpec
(
shape
=
[
batch_size
],
dtype
=
tf
.
string
)
signatures
=
{
'serving_default'
:
export_module
.
inference_from_image_bytes
.
get_concrete_function
(
input_signature
)
}
elif
input_type
==
'tf_example'
:
input_signature
=
tf
.
TensorSpec
(
shape
=
[
batch_size
],
dtype
=
tf
.
string
)
signatures
=
{
'serving_default'
:
export_module
.
inference_from_tf_example
.
get_concrete_function
(
input_signature
)
}
else
:
raise
ValueError
(
'Unrecognized `input_type`'
)
status
.
assert_existing_objects_matched
()
ckpt
.
save
(
os
.
path
.
join
(
output_checkpoint_directory
,
'ckpt'
))
tf
.
saved_model
.
save
(
export_module
,
output_saved_model_directory
,
signatures
=
signatures
)
train_utils
.
serialize_config
(
params
,
export_dir
)
def
main
(
_
):
params
=
exp_factory
.
get_exp_config
(
FLAGS
.
experiment
)
for
config_file
in
FLAGS
.
config_file
or
[]:
params
=
hyperparams
.
override_params_dict
(
params
,
config_file
,
is_strict
=
True
)
if
FLAGS
.
params_override
:
params
=
hyperparams
.
override_params_dict
(
params
,
FLAGS
.
params_override
,
is_strict
=
True
)
params
.
validate
()
params
.
lock
()
export_inference_graph
(
input_type
=
FLAGS
.
input_type
,
batch_size
=
FLAGS
.
batch_size
,
input_image_size
=
[
int
(
x
)
for
x
in
FLAGS
.
input_image_size
.
split
(
','
)],
params
=
params
,
checkpoint_path
=
FLAGS
.
checkpoint_path
,
export_dir
=
FLAGS
.
export_dir
)
if
__name__
==
'__main__'
:
app
.
run
(
main
)
official/vision/beta/serving/image_classification.py
0 → 100644
View file @
b0ccdb11
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Detection input and model functions for serving/inference."""
import
tensorflow
as
tf
from
official.vision.beta.modeling
import
factory
from
official.vision.beta.ops
import
preprocess_ops
from
official.vision.beta.serving
import
export_base
MEAN_RGB
=
(
0.485
*
255
,
0.456
*
255
,
0.406
*
255
)
STDDEV_RGB
=
(
0.229
*
255
,
0.224
*
255
,
0.225
*
255
)
class
ClassificationModule
(
export_base
.
ExportModule
):
"""classification Module."""
def
build_model
(
self
):
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
self
.
_batch_size
]
+
self
.
_input_image_size
+
[
3
])
self
.
_model
=
factory
.
build_classification_model
(
input_specs
=
input_specs
,
model_config
=
self
.
_params
.
task
.
model
,
l2_regularizer
=
None
)
return
self
.
_model
def
_build_inputs
(
self
,
image
):
"""Builds classification model inputs for serving."""
# Center crops and resizes image.
image
=
preprocess_ops
.
center_crop_image
(
image
)
image
=
tf
.
image
.
resize
(
image
,
self
.
_input_image_size
,
method
=
tf
.
image
.
ResizeMethod
.
BILINEAR
)
image
=
tf
.
reshape
(
image
,
[
self
.
_input_image_size
[
0
],
self
.
_input_image_size
[
1
],
3
])
# Normalizes image with mean and std pixel values.
image
=
preprocess_ops
.
normalize_image
(
image
,
offset
=
MEAN_RGB
,
scale
=
STDDEV_RGB
)
return
image
def
_run_inference_on_image_tensors
(
self
,
images
):
"""Cast image to float and run inference.
Args:
images: uint8 Tensor of shape [batch_size, None, None, 3]
Returns:
Tensor holding classification output logits.
"""
with
tf
.
device
(
'cpu:0'
):
images
=
tf
.
cast
(
images
,
dtype
=
tf
.
float32
)
images
=
tf
.
nest
.
map_structure
(
tf
.
identity
,
tf
.
map_fn
(
self
.
_build_inputs
,
elems
=
images
,
fn_output_signature
=
tf
.
TensorSpec
(
shape
=
self
.
_input_image_size
+
[
3
],
dtype
=
tf
.
float32
),
parallel_iterations
=
32
)
)
logits
=
self
.
_model
(
images
,
training
=
False
)
return
logits
official/vision/beta/serving/image_classification_test.py
0 → 100644
View file @
b0ccdb11
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test for image classification export lib."""
import
io
import
os
from
absl.testing
import
parameterized
import
numpy
as
np
from
PIL
import
Image
import
tensorflow
as
tf
from
official.common
import
registry_imports
# pylint: disable=unused-import
from
official.core
import
exp_factory
from
official.vision.beta.serving
import
image_classification
class
ImageClassificationExportTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
_get_classification_module
(
self
):
params
=
exp_factory
.
get_exp_config
(
'resnet_imagenet'
)
params
.
task
.
model
.
backbone
.
resnet
.
model_id
=
18
classification_module
=
image_classification
.
ClassificationModule
(
params
,
batch_size
=
1
,
input_image_size
=
[
224
,
224
])
return
classification_module
def
_export_from_module
(
self
,
module
,
input_type
,
save_directory
):
if
input_type
==
'image_tensor'
:
input_signature
=
tf
.
TensorSpec
(
shape
=
[
None
,
224
,
224
,
3
],
dtype
=
tf
.
uint8
)
signatures
=
{
'serving_default'
:
module
.
inference_from_image_tensors
.
get_concrete_function
(
input_signature
)
}
elif
input_type
==
'image_bytes'
:
input_signature
=
tf
.
TensorSpec
(
shape
=
[
None
],
dtype
=
tf
.
string
)
signatures
=
{
'serving_default'
:
module
.
inference_from_image_bytes
.
get_concrete_function
(
input_signature
)
}
elif
input_type
==
'tf_example'
:
input_signature
=
tf
.
TensorSpec
(
shape
=
[
None
],
dtype
=
tf
.
string
)
signatures
=
{
'serving_default'
:
module
.
inference_from_tf_example
.
get_concrete_function
(
input_signature
)
}
else
:
raise
ValueError
(
'Unrecognized `input_type`'
)
tf
.
saved_model
.
save
(
module
,
save_directory
,
signatures
=
signatures
)
def
_get_dummy_input
(
self
,
input_type
):
"""Get dummy input for the given input type."""
if
input_type
==
'image_tensor'
:
return
tf
.
zeros
((
1
,
224
,
224
,
3
),
dtype
=
np
.
uint8
)
elif
input_type
==
'image_bytes'
:
image
=
Image
.
fromarray
(
np
.
zeros
((
224
,
224
,
3
),
dtype
=
np
.
uint8
))
byte_io
=
io
.
BytesIO
()
image
.
save
(
byte_io
,
'PNG'
)
return
[
byte_io
.
getvalue
()]
elif
input_type
==
'tf_example'
:
image_tensor
=
tf
.
zeros
((
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
encoded_jpeg
=
tf
.
image
.
encode_jpeg
(
tf
.
constant
(
image_tensor
)).
numpy
()
example
=
tf
.
train
.
Example
(
features
=
tf
.
train
.
Features
(
feature
=
{
'image/encoded'
:
tf
.
train
.
Feature
(
bytes_list
=
tf
.
train
.
BytesList
(
value
=
[
encoded_jpeg
])),
})).
SerializeToString
()
return
[
example
]
@
parameterized
.
parameters
(
{
'input_type'
:
'image_tensor'
},
{
'input_type'
:
'image_bytes'
},
{
'input_type'
:
'tf_example'
},
)
def
test_export
(
self
,
input_type
=
'image_tensor'
):
tmp_dir
=
self
.
get_temp_dir
()
module
=
self
.
_get_classification_module
()
model
=
module
.
build_model
()
self
.
_export_from_module
(
module
,
input_type
,
tmp_dir
)
self
.
assertTrue
(
os
.
path
.
exists
(
os
.
path
.
join
(
tmp_dir
,
'saved_model.pb'
)))
self
.
assertTrue
(
os
.
path
.
exists
(
os
.
path
.
join
(
tmp_dir
,
'variables'
,
'variables.index'
)))
self
.
assertTrue
(
os
.
path
.
exists
(
os
.
path
.
join
(
tmp_dir
,
'variables'
,
'variables.data-00000-of-00001'
)))
imported
=
tf
.
saved_model
.
load
(
tmp_dir
)
classification_fn
=
imported
.
signatures
[
'serving_default'
]
images
=
self
.
_get_dummy_input
(
input_type
)
processed_images
=
tf
.
nest
.
map_structure
(
tf
.
stop_gradient
,
tf
.
map_fn
(
module
.
_build_inputs
,
elems
=
tf
.
zeros
((
1
,
224
,
224
,
3
),
dtype
=
tf
.
uint8
),
fn_output_signature
=
tf
.
TensorSpec
(
shape
=
[
224
,
224
,
3
],
dtype
=
tf
.
float32
)))
expected_output
=
model
(
processed_images
,
training
=
False
)
out
=
classification_fn
(
tf
.
constant
(
images
))
self
.
assertAllClose
(
out
[
'outputs'
].
numpy
(),
expected_output
.
numpy
())
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/vision/beta/tasks/__init__.py
View file @
b0ccdb11
...
@@ -18,3 +18,4 @@
...
@@ -18,3 +18,4 @@
from
official.vision.beta.tasks
import
image_classification
from
official.vision.beta.tasks
import
image_classification
from
official.vision.beta.tasks
import
maskrcnn
from
official.vision.beta.tasks
import
maskrcnn
from
official.vision.beta.tasks
import
retinanet
from
official.vision.beta.tasks
import
retinanet
from
official.vision.beta.tasks
import
video_classification
official/vision/beta/tasks/maskrcnn_test.py
deleted
100644 → 0
View file @
e61588cd
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for MaskRCNN task."""
# pylint: disable=unused-import
from
absl.testing
import
parameterized
import
orbit
import
tensorflow
as
tf
from
official.core
import
exp_factory
from
official.modeling
import
optimization
from
official.vision
import
beta
from
official.vision.beta.tasks
import
maskrcnn
class
RetinaNetTaskTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
@
parameterized
.
parameters
(
(
"fasterrcnn_resnetfpn_coco"
,
True
),
(
"fasterrcnn_resnetfpn_coco"
,
False
),
(
"maskrcnn_resnetfpn_coco"
,
True
),
(
"maskrcnn_resnetfpn_coco"
,
False
),
)
def
test_retinanet_task_train
(
self
,
test_config
,
is_training
):
"""RetinaNet task test for training and val using toy configs."""
config
=
exp_factory
.
get_exp_config
(
test_config
)
tf
.
keras
.
mixed_precision
.
experimental
.
Policy
(
"mixed_bfloat16"
)
# modify config to suit local testing
config
.
trainer
.
steps_per_loop
=
1
config
.
task
.
train_data
.
global_batch_size
=
2
config
.
task
.
model
.
input_size
=
[
384
,
384
,
3
]
config
.
train_steps
=
2
config
.
task
.
train_data
.
shuffle_buffer_size
=
10
config
.
task
.
train_data
.
input_path
=
"coco/train-00000-of-00256.tfrecord"
config
.
task
.
validation_data
.
global_batch_size
=
2
config
.
task
.
validation_data
.
input_path
=
"coco/val-00000-of-00032.tfrecord"
task
=
maskrcnn
.
MaskRCNNTask
(
config
.
task
)
model
=
task
.
build_model
()
metrics
=
task
.
build_metrics
(
training
=
is_training
)
strategy
=
tf
.
distribute
.
get_strategy
()
data_config
=
config
.
task
.
train_data
if
is_training
else
config
.
task
.
validation_data
dataset
=
orbit
.
utils
.
make_distributed_dataset
(
strategy
,
task
.
build_inputs
,
data_config
)
iterator
=
iter
(
dataset
)
opt_factory
=
optimization
.
OptimizerFactory
(
config
.
trainer
.
optimizer_config
)
optimizer
=
opt_factory
.
build_optimizer
(
opt_factory
.
build_learning_rate
())
if
is_training
:
task
.
train_step
(
next
(
iterator
),
model
,
optimizer
,
metrics
=
metrics
)
else
:
task
.
validation_step
(
next
(
iterator
),
model
,
metrics
=
metrics
)
if
__name__
==
"__main__"
:
tf
.
test
.
main
()
official/vision/beta/tasks/retinanet.py
View file @
b0ccdb11
...
@@ -131,7 +131,7 @@ class RetinaNetTask(base_task.Task):
...
@@ -131,7 +131,7 @@ class RetinaNetTask(base_task.Task):
def
build_losses
(
self
,
outputs
,
labels
,
aux_losses
=
None
):
def
build_losses
(
self
,
outputs
,
labels
,
aux_losses
=
None
):
"""Build RetinaNet losses."""
"""Build RetinaNet losses."""
params
=
self
.
task_config
params
=
self
.
task_config
cls_loss_fn
=
keras_cv
.
FocalLoss
(
cls_loss_fn
=
keras_cv
.
losses
.
FocalLoss
(
alpha
=
params
.
losses
.
focal_loss_alpha
,
alpha
=
params
.
losses
.
focal_loss_alpha
,
gamma
=
params
.
losses
.
focal_loss_gamma
,
gamma
=
params
.
losses
.
focal_loss_gamma
,
reduction
=
tf
.
keras
.
losses
.
Reduction
.
SUM
)
reduction
=
tf
.
keras
.
losses
.
Reduction
.
SUM
)
...
@@ -145,14 +145,14 @@ class RetinaNetTask(base_task.Task):
...
@@ -145,14 +145,14 @@ class RetinaNetTask(base_task.Task):
num_positives
=
tf
.
reduce_sum
(
box_sample_weight
)
+
1.0
num_positives
=
tf
.
reduce_sum
(
box_sample_weight
)
+
1.0
cls_sample_weight
=
cls_sample_weight
/
num_positives
cls_sample_weight
=
cls_sample_weight
/
num_positives
box_sample_weight
=
box_sample_weight
/
num_positives
box_sample_weight
=
box_sample_weight
/
num_positives
y_true_cls
=
keras_cv
.
multi_level_flatten
(
y_true_cls
=
keras_cv
.
losses
.
multi_level_flatten
(
labels
[
'cls_targets'
],
last_dim
=
None
)
labels
[
'cls_targets'
],
last_dim
=
None
)
y_true_cls
=
tf
.
one_hot
(
y_true_cls
,
params
.
model
.
num_classes
)
y_true_cls
=
tf
.
one_hot
(
y_true_cls
,
params
.
model
.
num_classes
)
y_pred_cls
=
keras_cv
.
multi_level_flatten
(
y_pred_cls
=
keras_cv
.
losses
.
multi_level_flatten
(
outputs
[
'cls_outputs'
],
last_dim
=
params
.
model
.
num_classes
)
outputs
[
'cls_outputs'
],
last_dim
=
params
.
model
.
num_classes
)
y_true_box
=
keras_cv
.
multi_level_flatten
(
y_true_box
=
keras_cv
.
losses
.
multi_level_flatten
(
labels
[
'box_targets'
],
last_dim
=
4
)
labels
[
'box_targets'
],
last_dim
=
4
)
y_pred_box
=
keras_cv
.
multi_level_flatten
(
y_pred_box
=
keras_cv
.
losses
.
multi_level_flatten
(
outputs
[
'box_outputs'
],
last_dim
=
4
)
outputs
[
'box_outputs'
],
last_dim
=
4
)
cls_loss
=
cls_loss_fn
(
cls_loss
=
cls_loss_fn
(
...
...
official/vision/beta/tasks/retinanet_test.py
deleted
100644 → 0
View file @
e61588cd
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for RetinaNet task."""
# pylint: disable=unused-import
from
absl.testing
import
parameterized
import
orbit
import
tensorflow
as
tf
from
official.core
import
exp_factory
from
official.modeling
import
optimization
from
official.vision
import
beta
from
official.vision.beta.tasks
import
retinanet
class
RetinaNetTaskTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
@
parameterized
.
parameters
(
(
"retinanet_resnetfpn_coco"
,
True
),
(
"retinanet_spinenet_coco"
,
True
),
)
def
test_retinanet_task_train
(
self
,
test_config
,
is_training
):
"""RetinaNet task test for training and val using toy configs."""
config
=
exp_factory
.
get_exp_config
(
test_config
)
# modify config to suit local testing
config
.
trainer
.
steps_per_loop
=
1
config
.
task
.
train_data
.
global_batch_size
=
2
config
.
task
.
validation_data
.
global_batch_size
=
2
config
.
task
.
train_data
.
shuffle_buffer_size
=
4
config
.
task
.
validation_data
.
shuffle_buffer_size
=
4
config
.
train_steps
=
2
task
=
retinanet
.
RetinaNetTask
(
config
.
task
)
model
=
task
.
build_model
()
metrics
=
task
.
build_metrics
()
strategy
=
tf
.
distribute
.
get_strategy
()
data_config
=
config
.
task
.
train_data
if
is_training
else
config
.
task
.
validation_data
dataset
=
orbit
.
utils
.
make_distributed_dataset
(
strategy
,
task
.
build_inputs
,
data_config
)
iterator
=
iter
(
dataset
)
opt_factory
=
optimization
.
OptimizerFactory
(
config
.
trainer
.
optimizer_config
)
optimizer
=
opt_factory
.
build_optimizer
(
opt_factory
.
build_learning_rate
())
if
is_training
:
task
.
train_step
(
next
(
iterator
),
model
,
optimizer
,
metrics
=
metrics
)
else
:
task
.
validation_step
(
next
(
iterator
),
model
,
metrics
=
metrics
)
if
__name__
==
"__main__"
:
tf
.
test
.
main
()
official/vision/beta/tasks/video_classification.py
0 → 100644
View file @
b0ccdb11
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Video classification task definition."""
import
tensorflow
as
tf
from
official.core
import
base_task
from
official.core
import
input_reader
from
official.core
import
task_factory
from
official.modeling
import
tf_utils
from
official.vision.beta.configs
import
video_classification
as
exp_cfg
from
official.vision.beta.dataloaders
import
video_input
from
official.vision.beta.modeling
import
factory
@
task_factory
.
register_task_cls
(
exp_cfg
.
VideoClassificationTask
)
class
VideoClassificationTask
(
base_task
.
Task
):
"""A task for video classification."""
def
build_model
(
self
):
"""Builds video classification model."""
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
None
,
None
,
None
,
None
,
3
])
l2_weight_decay
=
self
.
task_config
.
losses
.
l2_weight_decay
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
l2_regularizer
=
(
tf
.
keras
.
regularizers
.
l2
(
l2_weight_decay
/
2.0
)
if
l2_weight_decay
else
None
)
model
=
factory
.
build_video_classification_model
(
input_specs
=
input_specs
,
model_config
=
self
.
task_config
.
model
,
num_classes
=
self
.
task_config
.
train_data
.
num_classes
,
l2_regularizer
=
l2_regularizer
)
return
model
def
build_inputs
(
self
,
params
:
exp_cfg
.
DataConfig
,
input_context
=
None
):
"""Builds classification input."""
decoder
=
video_input
.
Decoder
()
decoder_fn
=
decoder
.
decode
parser
=
video_input
.
Parser
(
input_params
=
params
)
postprocess_fn
=
video_input
.
PostBatchProcessor
(
params
)
reader
=
input_reader
.
InputReader
(
params
,
dataset_fn
=
tf
.
data
.
TFRecordDataset
,
decoder_fn
=
decoder_fn
,
parser_fn
=
parser
.
parse_fn
(
params
.
is_training
),
postprocess_fn
=
postprocess_fn
)
dataset
=
reader
.
read
(
input_context
=
input_context
)
return
dataset
def
build_losses
(
self
,
labels
,
model_outputs
,
aux_losses
=
None
):
"""Sparse categorical cross entropy loss.
Args:
labels: labels.
model_outputs: Output logits of the classifier.
aux_losses: auxiliarly loss tensors, i.e. `losses` in keras.Model.
Returns:
The total loss tensor.
"""
losses_config
=
self
.
task_config
.
losses
if
losses_config
.
one_hot
:
total_loss
=
tf
.
keras
.
losses
.
categorical_crossentropy
(
labels
,
model_outputs
,
from_logits
=
True
,
label_smoothing
=
losses_config
.
label_smoothing
)
else
:
total_loss
=
tf
.
keras
.
losses
.
sparse_categorical_crossentropy
(
labels
,
model_outputs
,
from_logits
=
True
)
total_loss
=
tf_utils
.
safe_mean
(
total_loss
)
if
aux_losses
:
total_loss
+=
tf
.
add_n
(
aux_losses
)
return
total_loss
def
build_metrics
(
self
,
training
=
True
):
"""Gets streaming metrics for training/validation."""
if
self
.
task_config
.
losses
.
one_hot
:
metrics
=
[
tf
.
keras
.
metrics
.
CategoricalAccuracy
(
name
=
'accuracy'
),
tf
.
keras
.
metrics
.
TopKCategoricalAccuracy
(
k
=
1
,
name
=
'top_1_accuracy'
),
tf
.
keras
.
metrics
.
TopKCategoricalAccuracy
(
k
=
5
,
name
=
'top_5_accuracy'
)
]
else
:
metrics
=
[
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
name
=
'accuracy'
),
tf
.
keras
.
metrics
.
SparseTopKCategoricalAccuracy
(
k
=
1
,
name
=
'top_1_accuracy'
),
tf
.
keras
.
metrics
.
SparseTopKCategoricalAccuracy
(
k
=
5
,
name
=
'top_5_accuracy'
)
]
return
metrics
def
train_step
(
self
,
inputs
,
model
,
optimizer
,
metrics
=
None
):
"""Does forward and backward.
Args:
inputs: a dictionary of input tensors.
model: the model, forward pass definition.
optimizer: the optimizer for this training step.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
features
,
labels
=
inputs
num_replicas
=
tf
.
distribute
.
get_strategy
().
num_replicas_in_sync
with
tf
.
GradientTape
()
as
tape
:
outputs
=
model
(
features
[
'image'
],
training
=
True
)
# Casting output layer as float32 is necessary when mixed_precision is
# mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
outputs
=
tf
.
nest
.
map_structure
(
lambda
x
:
tf
.
cast
(
x
,
tf
.
float32
),
outputs
)
# Computes per-replica loss.
loss
=
self
.
build_losses
(
model_outputs
=
outputs
,
labels
=
labels
,
aux_losses
=
model
.
losses
)
# Scales loss as the default gradients allreduce performs sum inside the
# optimizer.
scaled_loss
=
loss
/
num_replicas
# For mixed_precision policy, when LossScaleOptimizer is used, loss is
# scaled for numerical stability.
if
isinstance
(
optimizer
,
tf
.
keras
.
mixed_precision
.
experimental
.
LossScaleOptimizer
):
scaled_loss
=
optimizer
.
get_scaled_loss
(
scaled_loss
)
tvars
=
model
.
trainable_variables
grads
=
tape
.
gradient
(
scaled_loss
,
tvars
)
# Scales back gradient before apply_gradients when LossScaleOptimizer is
# used.
if
isinstance
(
optimizer
,
tf
.
keras
.
mixed_precision
.
experimental
.
LossScaleOptimizer
):
grads
=
optimizer
.
get_unscaled_gradients
(
grads
)
# Apply gradient clipping.
if
self
.
task_config
.
gradient_clip_norm
>
0
:
grads
,
_
=
tf
.
clip_by_global_norm
(
grads
,
self
.
task_config
.
gradient_clip_norm
)
optimizer
.
apply_gradients
(
list
(
zip
(
grads
,
tvars
)))
logs
=
{
self
.
loss
:
loss
}
if
metrics
:
self
.
process_metrics
(
metrics
,
labels
,
outputs
)
logs
.
update
({
m
.
name
:
m
.
result
()
for
m
in
metrics
})
elif
model
.
compiled_metrics
:
self
.
process_compiled_metrics
(
model
.
compiled_metrics
,
labels
,
outputs
)
logs
.
update
({
m
.
name
:
m
.
result
()
for
m
in
model
.
metrics
})
return
logs
def
validation_step
(
self
,
inputs
,
model
,
metrics
=
None
):
"""Validatation step.
Args:
inputs: a dictionary of input tensors.
model: the keras.Model.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
features
,
labels
=
inputs
outputs
=
self
.
inference_step
(
features
[
'image'
],
model
)
outputs
=
tf
.
nest
.
map_structure
(
lambda
x
:
tf
.
cast
(
x
,
tf
.
float32
),
outputs
)
loss
=
self
.
build_losses
(
model_outputs
=
outputs
,
labels
=
labels
,
aux_losses
=
model
.
losses
)
logs
=
{
self
.
loss
:
loss
}
if
metrics
:
self
.
process_metrics
(
metrics
,
labels
,
outputs
)
logs
.
update
({
m
.
name
:
m
.
result
()
for
m
in
metrics
})
elif
model
.
compiled_metrics
:
self
.
process_compiled_metrics
(
model
.
compiled_metrics
,
labels
,
outputs
)
logs
.
update
({
m
.
name
:
m
.
result
()
for
m
in
model
.
metrics
})
return
logs
def
inference_step
(
self
,
inputs
,
model
):
"""Performs the forward step."""
return
model
(
inputs
,
training
=
False
)
official/vision/beta/train.py
View file @
b0ccdb11
...
@@ -19,13 +19,13 @@ from absl import app
...
@@ -19,13 +19,13 @@ from absl import app
from
absl
import
flags
from
absl
import
flags
import
gin
import
gin
from
official.common
import
distribute_utils
from
official.common
import
flags
as
tfm_flags
from
official.common
import
flags
as
tfm_flags
from
official.common
import
registry_imports
# pylint: disable=unused-import
from
official.common
import
registry_imports
# pylint: disable=unused-import
from
official.core
import
task_factory
from
official.core
import
task_factory
from
official.core
import
train_lib
from
official.core
import
train_lib
from
official.core
import
train_utils
from
official.core
import
train_utils
from
official.modeling
import
performance
from
official.modeling
import
performance
from
official.utils.misc
import
distribution_utils
FLAGS
=
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
...
@@ -46,7 +46,7 @@ def main(_):
...
@@ -46,7 +46,7 @@ def main(_):
if
params
.
runtime
.
mixed_precision_dtype
:
if
params
.
runtime
.
mixed_precision_dtype
:
performance
.
set_mixed_precision_policy
(
params
.
runtime
.
mixed_precision_dtype
,
performance
.
set_mixed_precision_policy
(
params
.
runtime
.
mixed_precision_dtype
,
params
.
runtime
.
loss_scale
)
params
.
runtime
.
loss_scale
)
distribution_strategy
=
distribut
ion
_utils
.
get_distribution_strategy
(
distribution_strategy
=
distribut
e
_utils
.
get_distribution_strategy
(
distribution_strategy
=
params
.
runtime
.
distribution_strategy
,
distribution_strategy
=
params
.
runtime
.
distribution_strategy
,
all_reduce_alg
=
params
.
runtime
.
all_reduce_alg
,
all_reduce_alg
=
params
.
runtime
.
all_reduce_alg
,
num_gpus
=
params
.
runtime
.
num_gpus
,
num_gpus
=
params
.
runtime
.
num_gpus
,
...
...
official/vision/detection/dataloader/anchor.py
View file @
b0ccdb11
...
@@ -21,11 +21,11 @@ from __future__ import print_function
...
@@ -21,11 +21,11 @@ from __future__ import print_function
import
collections
import
collections
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.vision
import
keras_cv
from
official.vision.detection.utils.object_detection
import
argmax_matcher
from
official.vision.detection.utils.object_detection
import
argmax_matcher
from
official.vision.detection.utils.object_detection
import
balanced_positive_negative_sampler
from
official.vision.detection.utils.object_detection
import
balanced_positive_negative_sampler
from
official.vision.detection.utils.object_detection
import
box_list
from
official.vision.detection.utils.object_detection
import
box_list
from
official.vision.detection.utils.object_detection
import
faster_rcnn_box_coder
from
official.vision.detection.utils.object_detection
import
faster_rcnn_box_coder
from
official.vision.detection.utils.object_detection
import
region_similarity_calculator
from
official.vision.detection.utils.object_detection
import
target_assigner
from
official.vision.detection.utils.object_detection
import
target_assigner
...
@@ -134,7 +134,7 @@ class AnchorLabeler(object):
...
@@ -134,7 +134,7 @@ class AnchorLabeler(object):
upper-bound threshold to assign negative labels for anchors. An anchor
upper-bound threshold to assign negative labels for anchors. An anchor
with a score below the threshold is labeled negative.
with a score below the threshold is labeled negative.
"""
"""
similarity_calc
=
region_similarity_calculator
.
IouSimilarity
()
similarity_calc
=
keras_cv
.
ops
.
IouSimilarity
()
matcher
=
argmax_matcher
.
ArgMaxMatcher
(
matcher
=
argmax_matcher
.
ArgMaxMatcher
(
match_threshold
,
match_threshold
,
unmatched_threshold
=
unmatched_threshold
,
unmatched_threshold
=
unmatched_threshold
,
...
...
official/vision/detection/main.py
View file @
b0ccdb11
...
@@ -14,28 +14,19 @@
...
@@ -14,28 +14,19 @@
# ==============================================================================
# ==============================================================================
"""Main function to train various object detection models."""
"""Main function to train various object detection models."""
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
import
functools
import
functools
import
pprint
import
pprint
# pylint: disable=g-bad-import-order
# Import libraries
import
tensorflow
as
tf
from
absl
import
app
from
absl
import
app
from
absl
import
flags
from
absl
import
flags
from
absl
import
logging
from
absl
import
logging
# pylint: enable=g-bad-import-order
import
tensorflow
as
tf
from
official.common
import
distribute_utils
from
official.modeling.hyperparams
import
params_dict
from
official.modeling.hyperparams
import
params_dict
from
official.modeling.training
import
distributed_executor
as
executor
from
official.modeling.training
import
distributed_executor
as
executor
from
official.utils
import
hyperparams_flags
from
official.utils
import
hyperparams_flags
from
official.utils.flags
import
core
as
flags_core
from
official.utils.flags
import
core
as
flags_core
from
official.utils.misc
import
distribution_utils
from
official.utils.misc
import
keras_utils
from
official.utils.misc
import
keras_utils
from
official.vision.detection.configs
import
factory
as
config_factory
from
official.vision.detection.configs
import
factory
as
config_factory
from
official.vision.detection.dataloader
import
input_reader
from
official.vision.detection.dataloader
import
input_reader
...
@@ -87,9 +78,9 @@ def run_executor(params,
...
@@ -87,9 +78,9 @@ def run_executor(params,
strategy
=
prebuilt_strategy
strategy
=
prebuilt_strategy
else
:
else
:
strategy_config
=
params
.
strategy_config
strategy_config
=
params
.
strategy_config
distribut
ion
_utils
.
configure_cluster
(
strategy_config
.
worker_hosts
,
distribut
e
_utils
.
configure_cluster
(
strategy_config
.
worker_hosts
,
strategy_config
.
task_index
)
strategy_config
.
task_index
)
strategy
=
distribut
ion
_utils
.
get_distribution_strategy
(
strategy
=
distribut
e
_utils
.
get_distribution_strategy
(
distribution_strategy
=
params
.
strategy_type
,
distribution_strategy
=
params
.
strategy_type
,
num_gpus
=
strategy_config
.
num_gpus
,
num_gpus
=
strategy_config
.
num_gpus
,
all_reduce_alg
=
strategy_config
.
all_reduce_alg
,
all_reduce_alg
=
strategy_config
.
all_reduce_alg
,
...
...
official/vision/detection/utils/object_detection/target_assigner.py
View file @
b0ccdb11
...
@@ -151,8 +151,8 @@ class TargetAssigner(object):
...
@@ -151,8 +151,8 @@ class TargetAssigner(object):
groundtruth_weights
=
tf
.
ones
([
num_gt_boxes
],
dtype
=
tf
.
float32
)
groundtruth_weights
=
tf
.
ones
([
num_gt_boxes
],
dtype
=
tf
.
float32
)
with
tf
.
control_dependencies
(
with
tf
.
control_dependencies
(
[
unmatched_shape_assert
,
labels_and_box_shapes_assert
]):
[
unmatched_shape_assert
,
labels_and_box_shapes_assert
]):
match_quality_matrix
=
self
.
_similarity_calc
.
compare
(
match_quality_matrix
=
self
.
_similarity_calc
(
groundtruth_boxes
,
anchors
)
groundtruth_boxes
.
get
()
,
anchors
.
get
()
)
match
=
self
.
_matcher
.
match
(
match_quality_matrix
,
**
params
)
match
=
self
.
_matcher
.
match
(
match_quality_matrix
,
**
params
)
reg_targets
=
self
.
_create_regression_targets
(
anchors
,
groundtruth_boxes
,
reg_targets
=
self
.
_create_regression_targets
(
anchors
,
groundtruth_boxes
,
match
)
match
)
...
...
official/vision/image_classification/callbacks.py
View file @
b0ccdb11
...
@@ -29,16 +29,18 @@ from official.modeling import optimization
...
@@ -29,16 +29,18 @@ from official.modeling import optimization
from
official.utils.misc
import
keras_utils
from
official.utils.misc
import
keras_utils
def
get_callbacks
(
model_checkpoint
:
bool
=
True
,
def
get_callbacks
(
include_tensorboard
:
bool
=
True
,
model_checkpoint
:
bool
=
True
,
time_history
:
bool
=
True
,
include_tensorboard
:
bool
=
True
,
track_lr
:
bool
=
True
,
time_history
:
bool
=
True
,
write_model_weights
:
bool
=
True
,
track_lr
:
bool
=
True
,
apply_moving_average
:
bool
=
False
,
write_model_weights
:
bool
=
True
,
initial_step
:
int
=
0
,
apply_moving_average
:
bool
=
False
,
batch_size
:
int
=
0
,
initial_step
:
int
=
0
,
log_steps
:
int
=
0
,
batch_size
:
int
=
0
,
model_dir
:
str
=
None
)
->
List
[
tf
.
keras
.
callbacks
.
Callback
]:
log_steps
:
int
=
0
,
model_dir
:
str
=
None
,
backup_and_restore
:
bool
=
False
)
->
List
[
tf
.
keras
.
callbacks
.
Callback
]:
"""Get all callbacks."""
"""Get all callbacks."""
model_dir
=
model_dir
or
''
model_dir
=
model_dir
or
''
callbacks
=
[]
callbacks
=
[]
...
@@ -47,6 +49,10 @@ def get_callbacks(model_checkpoint: bool = True,
...
@@ -47,6 +49,10 @@ def get_callbacks(model_checkpoint: bool = True,
callbacks
.
append
(
callbacks
.
append
(
tf
.
keras
.
callbacks
.
ModelCheckpoint
(
tf
.
keras
.
callbacks
.
ModelCheckpoint
(
ckpt_full_path
,
save_weights_only
=
True
,
verbose
=
1
))
ckpt_full_path
,
save_weights_only
=
True
,
verbose
=
1
))
if
backup_and_restore
:
backup_dir
=
os
.
path
.
join
(
model_dir
,
'tmp'
)
callbacks
.
append
(
tf
.
keras
.
callbacks
.
experimental
.
BackupAndRestore
(
backup_dir
))
if
include_tensorboard
:
if
include_tensorboard
:
callbacks
.
append
(
callbacks
.
append
(
CustomTensorBoard
(
CustomTensorBoard
(
...
...
official/vision/image_classification/classifier_trainer.py
View file @
b0ccdb11
...
@@ -23,11 +23,10 @@ from absl import app
...
@@ -23,11 +23,10 @@ from absl import app
from
absl
import
flags
from
absl
import
flags
from
absl
import
logging
from
absl
import
logging
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.common
import
distribute_utils
from
official.modeling
import
hyperparams
from
official.modeling
import
hyperparams
from
official.modeling
import
performance
from
official.modeling
import
performance
from
official.utils
import
hyperparams_flags
from
official.utils
import
hyperparams_flags
from
official.utils.misc
import
distribution_utils
from
official.utils.misc
import
keras_utils
from
official.utils.misc
import
keras_utils
from
official.vision.image_classification
import
callbacks
as
custom_callbacks
from
official.vision.image_classification
import
callbacks
as
custom_callbacks
from
official.vision.image_classification
import
dataset_factory
from
official.vision.image_classification
import
dataset_factory
...
@@ -291,17 +290,17 @@ def train_and_eval(
...
@@ -291,17 +290,17 @@ def train_and_eval(
"""Runs the train and eval path using compile/fit."""
"""Runs the train and eval path using compile/fit."""
logging
.
info
(
'Running train and eval.'
)
logging
.
info
(
'Running train and eval.'
)
distribut
ion
_utils
.
configure_cluster
(
params
.
runtime
.
worker_hosts
,
distribut
e
_utils
.
configure_cluster
(
params
.
runtime
.
worker_hosts
,
params
.
runtime
.
task_index
)
params
.
runtime
.
task_index
)
# Note: for TPUs, strategy and scope should be created before the dataset
# Note: for TPUs, strategy and scope should be created before the dataset
strategy
=
strategy_override
or
distribut
ion
_utils
.
get_distribution_strategy
(
strategy
=
strategy_override
or
distribut
e
_utils
.
get_distribution_strategy
(
distribution_strategy
=
params
.
runtime
.
distribution_strategy
,
distribution_strategy
=
params
.
runtime
.
distribution_strategy
,
all_reduce_alg
=
params
.
runtime
.
all_reduce_alg
,
all_reduce_alg
=
params
.
runtime
.
all_reduce_alg
,
num_gpus
=
params
.
runtime
.
num_gpus
,
num_gpus
=
params
.
runtime
.
num_gpus
,
tpu_address
=
params
.
runtime
.
tpu
)
tpu_address
=
params
.
runtime
.
tpu
)
strategy_scope
=
distribut
ion
_utils
.
get_strategy_scope
(
strategy
)
strategy_scope
=
distribut
e
_utils
.
get_strategy_scope
(
strategy
)
logging
.
info
(
'Detected %d devices.'
,
logging
.
info
(
'Detected %d devices.'
,
strategy
.
num_replicas_in_sync
if
strategy
else
1
)
strategy
.
num_replicas_in_sync
if
strategy
else
1
)
...
@@ -369,7 +368,8 @@ def train_and_eval(
...
@@ -369,7 +368,8 @@ def train_and_eval(
initial_step
=
initial_epoch
*
train_steps
,
initial_step
=
initial_epoch
*
train_steps
,
batch_size
=
train_builder
.
global_batch_size
,
batch_size
=
train_builder
.
global_batch_size
,
log_steps
=
params
.
train
.
time_history
.
log_steps
,
log_steps
=
params
.
train
.
time_history
.
log_steps
,
model_dir
=
params
.
model_dir
)
model_dir
=
params
.
model_dir
,
backup_and_restore
=
params
.
train
.
callbacks
.
enable_backup_and_restore
)
serialize_config
(
params
=
params
,
model_dir
=
params
.
model_dir
)
serialize_config
(
params
=
params
,
model_dir
=
params
.
model_dir
)
...
...
official/vision/image_classification/mnist_main.py
View file @
b0ccdb11
...
@@ -25,9 +25,8 @@ from absl import flags
...
@@ -25,9 +25,8 @@ from absl import flags
from
absl
import
logging
from
absl
import
logging
import
tensorflow
as
tf
import
tensorflow
as
tf
import
tensorflow_datasets
as
tfds
import
tensorflow_datasets
as
tfds
from
official.common
import
distribute_utils
from
official.utils.flags
import
core
as
flags_core
from
official.utils.flags
import
core
as
flags_core
from
official.utils.misc
import
distribution_utils
from
official.utils.misc
import
model_helpers
from
official.utils.misc
import
model_helpers
from
official.vision.image_classification.resnet
import
common
from
official.vision.image_classification.resnet
import
common
...
@@ -82,12 +81,12 @@ def run(flags_obj, datasets_override=None, strategy_override=None):
...
@@ -82,12 +81,12 @@ def run(flags_obj, datasets_override=None, strategy_override=None):
Returns:
Returns:
Dictionary of training and eval stats.
Dictionary of training and eval stats.
"""
"""
strategy
=
strategy_override
or
distribut
ion
_utils
.
get_distribution_strategy
(
strategy
=
strategy_override
or
distribut
e
_utils
.
get_distribution_strategy
(
distribution_strategy
=
flags_obj
.
distribution_strategy
,
distribution_strategy
=
flags_obj
.
distribution_strategy
,
num_gpus
=
flags_obj
.
num_gpus
,
num_gpus
=
flags_obj
.
num_gpus
,
tpu_address
=
flags_obj
.
tpu
)
tpu_address
=
flags_obj
.
tpu
)
strategy_scope
=
distribut
ion
_utils
.
get_strategy_scope
(
strategy
)
strategy_scope
=
distribut
e
_utils
.
get_strategy_scope
(
strategy
)
mnist
=
tfds
.
builder
(
'mnist'
,
data_dir
=
flags_obj
.
data_dir
)
mnist
=
tfds
.
builder
(
'mnist'
,
data_dir
=
flags_obj
.
data_dir
)
if
flags_obj
.
download
:
if
flags_obj
.
download
:
...
...
official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py
View file @
b0ccdb11
...
@@ -23,10 +23,9 @@ from absl import flags
...
@@ -23,10 +23,9 @@ from absl import flags
from
absl
import
logging
from
absl
import
logging
import
orbit
import
orbit
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.common
import
distribute_utils
from
official.modeling
import
performance
from
official.modeling
import
performance
from
official.utils.flags
import
core
as
flags_core
from
official.utils.flags
import
core
as
flags_core
from
official.utils.misc
import
distribution_utils
from
official.utils.misc
import
keras_utils
from
official.utils.misc
import
keras_utils
from
official.utils.misc
import
model_helpers
from
official.utils.misc
import
model_helpers
from
official.vision.image_classification.resnet
import
common
from
official.vision.image_classification.resnet
import
common
...
@@ -117,7 +116,7 @@ def run(flags_obj):
...
@@ -117,7 +116,7 @@ def run(flags_obj):
else
'channels_last'
)
else
'channels_last'
)
tf
.
keras
.
backend
.
set_image_data_format
(
data_format
)
tf
.
keras
.
backend
.
set_image_data_format
(
data_format
)
strategy
=
distribut
ion
_utils
.
get_distribution_strategy
(
strategy
=
distribut
e
_utils
.
get_distribution_strategy
(
distribution_strategy
=
flags_obj
.
distribution_strategy
,
distribution_strategy
=
flags_obj
.
distribution_strategy
,
num_gpus
=
flags_obj
.
num_gpus
,
num_gpus
=
flags_obj
.
num_gpus
,
all_reduce_alg
=
flags_obj
.
all_reduce_alg
,
all_reduce_alg
=
flags_obj
.
all_reduce_alg
,
...
@@ -144,7 +143,7 @@ def run(flags_obj):
...
@@ -144,7 +143,7 @@ def run(flags_obj):
flags_obj
.
batch_size
,
flags_obj
.
batch_size
,
flags_obj
.
log_steps
,
flags_obj
.
log_steps
,
logdir
=
flags_obj
.
model_dir
if
flags_obj
.
enable_tensorboard
else
None
)
logdir
=
flags_obj
.
model_dir
if
flags_obj
.
enable_tensorboard
else
None
)
with
distribut
ion
_utils
.
get_strategy_scope
(
strategy
):
with
distribut
e
_utils
.
get_strategy_scope
(
strategy
):
runnable
=
resnet_runnable
.
ResnetRunnable
(
flags_obj
,
time_callback
,
runnable
=
resnet_runnable
.
ResnetRunnable
(
flags_obj
,
time_callback
,
per_epoch_steps
)
per_epoch_steps
)
...
...
Prev
1
…
4
5
6
7
8
9
10
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment