ModelZoo / ResNet50_tensorflow · Commits

Commit 1f8b5b27 (Unverified)
Merge branch 'master' into master
Authored Sep 03, 2021 by Simon Geisler; committed by GitHub on Sep 03, 2021
Parents: 0eeeaf98, 8fcf177e
Showing 20 changed files with 1513 additions and 34 deletions (+1513 / -34):
- official/vision/beta/projects/video_ssl/losses/losses.py (+136 / -0)
- official/vision/beta/projects/video_ssl/modeling/video_ssl_model.py (+180 / -0)
- official/vision/beta/projects/video_ssl/ops/video_ssl_preprocess_ops.py (+406 / -0)
- official/vision/beta/projects/video_ssl/ops/video_ssl_preprocess_ops_test.py (+47 / -0)
- official/vision/beta/projects/video_ssl/tasks/__init__.py (+4 / -0)
- official/vision/beta/projects/video_ssl/tasks/linear_eval.py (+71 / -0)
- official/vision/beta/projects/video_ssl/tasks/pretrain.py (+186 / -0)
- official/vision/beta/projects/video_ssl/tasks/pretrain_test.py (+82 / -0)
- official/vision/beta/projects/video_ssl/train.py (+78 / -0)
- official/vision/beta/projects/volumetric_models/tasks/semantic_segmentation_3d.py (+3 / -3)
- official/vision/beta/serving/export_base.py (+11 / -0)
- official/vision/beta/serving/export_tflite.py (+89 / -0)
- official/vision/beta/serving/export_tflite_lib.py (+114 / -0)
- official/vision/beta/serving/export_tflite_lib_test.py (+76 / -0)
- official/vision/beta/tasks/image_classification.py (+3 / -3)
- official/vision/beta/tasks/maskrcnn.py (+10 / -7)
- official/vision/beta/tasks/retinanet.py (+10 / -7)
- official/vision/beta/tasks/semantic_segmentation.py (+3 / -3)
- official/vision/beta/tasks/video_classification.py (+3 / -3)
- official/vision/detection/ops/postprocess_ops.py (+1 / -8)
official/vision/beta/projects/video_ssl/losses/losses.py (new file, mode 100644)
```python
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Define losses."""

# Import libraries
import tensorflow as tf
from tensorflow.compiler.tf2xla.python import xla


def contrastive_loss(hidden, num_replicas, normalize_hidden, temperature,
                     model, weight_decay):
  """Computes contrastive loss.

  Args:
    hidden: embedding of video clips after projection head.
    num_replicas: number of distributed replicas.
    normalize_hidden: whether or not to l2 normalize the hidden vector.
    temperature: temperature in the InfoNCE contrastive loss.
    model: keras model for calculating weight decay.
    weight_decay: weight decay parameter.

  Returns:
    A dictionary containing the total loss, the contrastive loss, the
    regularization loss, and the contrastive accuracy and entropy metrics.
  """
  large_num = 1e9
  hidden1, hidden2 = tf.split(hidden, num_or_size_splits=2, axis=0)
  if normalize_hidden:
    hidden1 = tf.math.l2_normalize(hidden1, -1)
    hidden2 = tf.math.l2_normalize(hidden2, -1)
  batch_size = tf.shape(hidden1)[0]

  if num_replicas == 1:
    # This is the local version.
    hidden1_large = hidden1
    hidden2_large = hidden2
    labels = tf.one_hot(tf.range(batch_size), batch_size * 2)
    masks = tf.one_hot(tf.range(batch_size), batch_size)
  else:
    # This is the cross-TPU version.
    hidden1_large = tpu_cross_replica_concat(hidden1, num_replicas)
    hidden2_large = tpu_cross_replica_concat(hidden2, num_replicas)
    enlarged_batch_size = tf.shape(hidden1_large)[0]
    replica_id = tf.cast(tf.cast(xla.replica_id(), tf.uint32), tf.int32)
    labels_idx = tf.range(batch_size) + replica_id * batch_size
    labels = tf.one_hot(labels_idx, enlarged_batch_size * 2)
    masks = tf.one_hot(labels_idx, enlarged_batch_size)

  logits_aa = tf.matmul(hidden1, hidden1_large, transpose_b=True) / temperature
  logits_aa = logits_aa - tf.cast(masks, logits_aa.dtype) * large_num
  logits_bb = tf.matmul(hidden2, hidden2_large, transpose_b=True) / temperature
  logits_bb = logits_bb - tf.cast(masks, logits_bb.dtype) * large_num
  logits_ab = tf.matmul(hidden1, hidden2_large, transpose_b=True) / temperature
  logits_ba = tf.matmul(hidden2, hidden1_large, transpose_b=True) / temperature

  loss_a = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(
          labels, tf.concat([logits_ab, logits_aa], 1)))
  loss_b = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(
          labels, tf.concat([logits_ba, logits_bb], 1)))
  loss = loss_a + loss_b

  l2_loss = weight_decay * tf.add_n([
      tf.nn.l2_loss(v)
      for v in model.trainable_variables
      if 'kernel' in v.name
  ])
  total_loss = loss + tf.cast(l2_loss, loss.dtype)

  contrast_prob = tf.nn.softmax(logits_ab)
  contrast_entropy = -tf.reduce_mean(
      tf.reduce_sum(contrast_prob * tf.math.log(contrast_prob + 1e-8), -1))

  contrast_acc = tf.equal(tf.argmax(labels, 1), tf.argmax(logits_ab, axis=1))
  contrast_acc = tf.reduce_mean(tf.cast(contrast_acc, tf.float32))

  return {
      'total_loss': total_loss,
      'contrastive_loss': loss,
      'reg_loss': l2_loss,
      'contrast_acc': contrast_acc,
      'contrast_entropy': contrast_entropy,
  }


def tpu_cross_replica_concat(tensor, num_replicas):
  """Reduce a concatenation of the `tensor` across TPU cores.

  Args:
    tensor: tensor to concatenate.
    num_replicas: number of TPU device replicas.

  Returns:
    Tensor of the same rank as `tensor` with first dimension `num_replicas`
    times larger.
  """
  with tf.name_scope('tpu_cross_replica_concat'):
    # This creates a tensor that is like the input tensor but has an added
    # replica dimension as the outermost dimension. On each replica it will
    # contain the local values and zeros for all other values that need to be
    # fetched from other replicas.
    ext_tensor = tf.scatter_nd(
        indices=[[xla.replica_id()]],
        updates=[tensor],
        shape=[num_replicas] + tensor.shape.as_list())

    # As every value is only present on one replica and 0 in all others, adding
    # them all together will result in the full tensor on all replicas.
    replica_context = tf.distribute.get_replica_context()
    ext_tensor = replica_context.all_reduce(tf.distribute.ReduceOp.SUM,
                                            ext_tensor)

    # Flatten the replica dimension.
    # The first dimension size will be: tensor.shape[0] * num_replicas
    # Using [-1] trick to support also scalar input.
    return tf.reshape(ext_tensor, [-1] + ext_tensor.shape.as_list()[2:])
```
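For reference, a minimal single-replica usage sketch (not part of this commit; the stand-in projection head, shapes, and hyperparameters below are assumptions). `hidden` stacks the two augmented views of each clip along the batch axis, so 8 clips yield a [16, projection_dim] tensor:

```python
import tensorflow as tf

# A stand-in projection head so `model.trainable_variables` contains 'kernel'
# variables for the weight-decay term; any Keras model works the same way.
projection_head = tf.keras.Sequential(
    [tf.keras.layers.Dense(128, input_shape=(2048,))])

# Two views per clip, stacked along the batch axis: [2 * 8, projection_dim].
hidden = projection_head(tf.random.normal([16, 2048]))

loss_dict = contrastive_loss(
    hidden=hidden,
    num_replicas=1,          # takes the local (non-TPU) branch
    normalize_hidden=True,
    temperature=0.1,
    model=projection_head,
    weight_decay=1e-4)
print(loss_dict['total_loss'], loss_dict['contrast_acc'])
```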
official/vision/beta/projects/video_ssl/modeling/video_ssl_model.py (new file, mode 100644)
```python
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Build video classification models."""
from typing import Mapping, Optional

# Import libraries
import tensorflow as tf

from official.modeling import tf_utils
from official.vision.beta.modeling import backbones
from official.vision.beta.modeling import factory_3d as model_factory
from official.vision.beta.projects.video_ssl.configs import video_ssl as video_ssl_cfg

layers = tf.keras.layers


@tf.keras.utils.register_keras_serializable(package='Vision')
class VideoSSLModel(tf.keras.Model):
  """A video ssl model class builder."""

  def __init__(self,
               backbone,
               normalize_feature,
               hidden_dim,
               hidden_layer_num,
               hidden_norm_args,
               projection_dim,
               input_specs: Optional[Mapping[str,
                                             tf.keras.layers.InputSpec]] = None,
               dropout_rate: float = 0.0,
               aggregate_endpoints: bool = False,
               kernel_initializer='random_uniform',
               kernel_regularizer=None,
               bias_regularizer=None,
               **kwargs):
    """Video Classification initialization function.

    Args:
      backbone: a 3d backbone network.
      normalize_feature: whether to normalize the backbone feature.
      hidden_dim: `int` number of hidden units in the MLP.
      hidden_layer_num: `int` number of hidden layers in the MLP.
      hidden_norm_args: `dict` for batchnorm arguments in the MLP.
      projection_dim: `int` number of output dimensions of the MLP.
      input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
      dropout_rate: `float` rate for dropout regularization.
      aggregate_endpoints: `bool` aggregate all end points or only use the
        final end point.
      kernel_initializer: kernel initializer for the dense layer.
      kernel_regularizer: tf.keras.regularizers.Regularizer object. Default to
        None.
      bias_regularizer: tf.keras.regularizers.Regularizer object. Default to
        None.
      **kwargs: keyword arguments to be passed.
    """
    if not input_specs:
      input_specs = {
          'image': layers.InputSpec(shape=[None, None, None, None, 3])
      }
    self._self_setattr_tracking = False
    self._config_dict = {
        'backbone': backbone,
        'normalize_feature': normalize_feature,
        'hidden_dim': hidden_dim,
        'hidden_layer_num': hidden_layer_num,
        'use_sync_bn': hidden_norm_args.use_sync_bn,
        'norm_momentum': hidden_norm_args.norm_momentum,
        'norm_epsilon': hidden_norm_args.norm_epsilon,
        'activation': hidden_norm_args.activation,
        'projection_dim': projection_dim,
        'input_specs': input_specs,
        'dropout_rate': dropout_rate,
        'aggregate_endpoints': aggregate_endpoints,
        'kernel_initializer': kernel_initializer,
        'kernel_regularizer': kernel_regularizer,
        'bias_regularizer': bias_regularizer,
    }
    self._input_specs = input_specs
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    self._backbone = backbone

    inputs = {
        k: tf.keras.Input(shape=v.shape[1:]) for k, v in input_specs.items()
    }
    endpoints = backbone(inputs['image'])

    if aggregate_endpoints:
      pooled_feats = []
      for endpoint in endpoints.values():
        x_pool = tf.keras.layers.GlobalAveragePooling3D()(endpoint)
        pooled_feats.append(x_pool)
      x = tf.concat(pooled_feats, axis=1)
    else:
      x = endpoints[max(endpoints.keys())]
      x = tf.keras.layers.GlobalAveragePooling3D()(x)

    # L2 normalize the feature after the backbone.
    if normalize_feature:
      x = tf.nn.l2_normalize(x, axis=-1)

    # MLP hidden layers.
    for _ in range(hidden_layer_num):
      x = tf.keras.layers.Dense(hidden_dim)(x)
      if self._config_dict['use_sync_bn']:
        x = tf.keras.layers.experimental.SyncBatchNormalization(
            momentum=self._config_dict['norm_momentum'],
            epsilon=self._config_dict['norm_epsilon'])(x)
      else:
        x = tf.keras.layers.BatchNormalization(
            momentum=self._config_dict['norm_momentum'],
            epsilon=self._config_dict['norm_epsilon'])(x)
      x = tf_utils.get_activation(self._config_dict['activation'])(x)

    # Projection head.
    x = tf.keras.layers.Dense(projection_dim)(x)

    super(VideoSSLModel, self).__init__(inputs=inputs, outputs=x, **kwargs)

  @property
  def checkpoint_items(self):
    """Returns a dictionary of items to be additionally checkpointed."""
    return dict(backbone=self.backbone)

  @property
  def backbone(self):
    return self._backbone

  def get_config(self):
    return self._config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)


@model_factory.register_model_builder('video_ssl_model')
def build_video_ssl_pretrain_model(
    input_specs: tf.keras.layers.InputSpec,
    model_config: video_ssl_cfg.VideoSSLModel,
    num_classes: int,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None):
  """Builds the video classification model."""
  del num_classes
  input_specs_dict = {'image': input_specs}
  backbone = backbones.factory.build_backbone(
      input_specs=input_specs,
      backbone_config=model_config.backbone,
      norm_activation_config=model_config.norm_activation,
      l2_regularizer=l2_regularizer)

  # The norm layer type in the MLP head should match the backbone.
  assert (model_config.norm_activation.use_sync_bn ==
          model_config.hidden_norm_activation.use_sync_bn)

  model = VideoSSLModel(
      backbone=backbone,
      normalize_feature=model_config.normalize_feature,
      hidden_dim=model_config.hidden_dim,
      hidden_layer_num=model_config.hidden_layer_num,
      hidden_norm_args=model_config.hidden_norm_activation,
      projection_dim=model_config.projection_dim,
      input_specs=input_specs_dict,
      dropout_rate=model_config.dropout_rate,
      aggregate_endpoints=model_config.aggregate_endpoints,
      kernel_regularizer=l2_regularizer)
  return model
```
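The head appended after the backbone is a standard SimCLR-style MLP. A standalone sketch of the same pattern (sizes here are illustrative assumptions, and a plain ReLU stands in for `tf_utils.get_activation`):

```python
import tensorflow as tf

# Backbone endpoint [batch, frames, h, w, channels], pooled to [batch, channels].
features = tf.keras.Input(shape=(4, 7, 7, 2048))
x = tf.keras.layers.GlobalAveragePooling3D()(features)
for _ in range(3):                               # hidden_layer_num
  x = tf.keras.layers.Dense(2048)(x)             # hidden_dim
  x = tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5)(x)
  x = tf.keras.layers.Activation('relu')(x)
x = tf.keras.layers.Dense(128)(x)                # projection_dim
projection_head = tf.keras.Model(features, x)
```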
official/vision/beta/projects/video_ssl/ops/video_ssl_preprocess_ops.py (new file, mode 100644)
```python
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Utils for custom ops for video ssl."""

import functools
from typing import Optional

import tensorflow as tf


def random_apply(func, p, x):
  """Randomly applies function func to x with probability p."""
  return tf.cond(
      tf.less(
          tf.random.uniform([], minval=0, maxval=1, dtype=tf.float32),
          tf.cast(p, tf.float32)), lambda: func(x), lambda: x)


def random_brightness(image, max_delta):
  """Distorts brightness of image (SimCLRv2 style)."""
  factor = tf.random.uniform([], tf.maximum(1.0 - max_delta, 0),
                             1.0 + max_delta)
  image = image * factor
  return image


def random_solarization(image, p=0.2):
  """Randomly solarizes the image."""

  def _transform(image):
    image = image * tf.cast(tf.less(image, 0.5), dtype=image.dtype) + (
        1.0 - image) * tf.cast(tf.greater_equal(image, 0.5), dtype=image.dtype)
    return image

  return random_apply(_transform, p=p, x=image)


def to_grayscale(image, keep_channels=True):
  """Converts the input image to grayscale.

  Args:
    image: The input image tensor.
    keep_channels: Whether to maintain the channel number of the image.
      If true, the transformed image repeats three times in channel.
      If false, the transformed image has only one channel.

  Returns:
    The distorted image tensor.
  """
  image = tf.image.rgb_to_grayscale(image)
  if keep_channels:
    image = tf.tile(image, [1, 1, 3])
  return image


def color_jitter(image, strength, random_order=True):
  """Distorts the color of the image (SimCLRv2 style).

  Args:
    image: The input image tensor.
    strength: The floating number for the strength of the color augmentation.
    random_order: A bool, specifying whether to randomize the jittering order.

  Returns:
    The distorted image tensor.
  """
  brightness = 0.8 * strength
  contrast = 0.8 * strength
  saturation = 0.8 * strength
  hue = 0.2 * strength
  if random_order:
    return color_jitter_rand(image, brightness, contrast, saturation, hue)
  else:
    return color_jitter_nonrand(image, brightness, contrast, saturation, hue)


def color_jitter_nonrand(image, brightness=0, contrast=0, saturation=0, hue=0):
  """Distorts the color of the image (jittering order is fixed, SimCLRv2 style).

  Args:
    image: The input image tensor.
    brightness: A float, specifying the brightness for color jitter.
    contrast: A float, specifying the contrast for color jitter.
    saturation: A float, specifying the saturation for color jitter.
    hue: A float, specifying the hue for color jitter.

  Returns:
    The distorted image tensor.
  """
  with tf.name_scope('distort_color'):

    def apply_transform(i, x, brightness, contrast, saturation, hue):
      """Apply the i-th transformation."""
      if brightness != 0 and i == 0:
        x = random_brightness(x, max_delta=brightness)
      elif contrast != 0 and i == 1:
        x = tf.image.random_contrast(x, lower=1 - contrast, upper=1 + contrast)
      elif saturation != 0 and i == 2:
        x = tf.image.random_saturation(
            x, lower=1 - saturation, upper=1 + saturation)
      elif hue != 0:
        x = tf.image.random_hue(x, max_delta=hue)
      return x

    for i in range(4):
      image = apply_transform(i, image, brightness, contrast, saturation, hue)
      image = tf.clip_by_value(image, 0., 1.)
    return image


def color_jitter_rand(image, brightness=0, contrast=0, saturation=0, hue=0):
  """Distorts the color of the image (jittering order is random, SimCLRv2 style).

  Args:
    image: The input image tensor.
    brightness: A float, specifying the brightness for color jitter.
    contrast: A float, specifying the contrast for color jitter.
    saturation: A float, specifying the saturation for color jitter.
    hue: A float, specifying the hue for color jitter.

  Returns:
    The distorted image tensor.
  """
  with tf.name_scope('distort_color'):

    def apply_transform(i, x):
      """Apply the i-th transformation."""

      def brightness_transform():
        if brightness == 0:
          return x
        else:
          return random_brightness(x, max_delta=brightness)

      def contrast_transform():
        if contrast == 0:
          return x
        else:
          return tf.image.random_contrast(
              x, lower=1 - contrast, upper=1 + contrast)

      def saturation_transform():
        if saturation == 0:
          return x
        else:
          return tf.image.random_saturation(
              x, lower=1 - saturation, upper=1 + saturation)

      def hue_transform():
        if hue == 0:
          return x
        else:
          return tf.image.random_hue(x, max_delta=hue)

      # pylint:disable=g-long-lambda
      x = tf.cond(
          tf.less(i, 2),
          lambda: tf.cond(tf.less(i, 1), brightness_transform,
                          contrast_transform),
          lambda: tf.cond(tf.less(i, 3), saturation_transform, hue_transform))
      # pylint:enable=g-long-lambda
      return x

    perm = tf.random.shuffle(tf.range(4))
    for i in range(4):
      image = apply_transform(perm[i], image)
      image = tf.clip_by_value(image, 0., 1.)
    return image


def random_color_jitter_3d(frames):
  """Applies temporally consistent color jittering to one video clip.

  Args:
    frames: `Tensor` of shape [num_frames, height, width, channels].

  Returns:
    A Tensor of shape [num_frames, height, width, channels] being color
    jittered with the same operation.
  """

  def random_color_jitter(image, p=1.0):

    def _transform(image):
      color_jitter_t = functools.partial(color_jitter, strength=1.0)
      image = random_apply(color_jitter_t, p=0.8, x=image)
      return random_apply(to_grayscale, p=0.2, x=image)

    return random_apply(_transform, p=p, x=image)

  num_frames, width, height, channels = frames.shape.as_list()
  big_image = tf.reshape(frames, [num_frames * width, height, channels])
  big_image = random_color_jitter(big_image)
  return tf.reshape(big_image, [num_frames, width, height, channels])


def gaussian_blur(image, kernel_size, sigma, padding='SAME'):
  """Blurs the given image with separable convolution.

  Args:
    image: Tensor of shape [height, width, channels] and dtype float to blur.
    kernel_size: Integer Tensor for the size of the blur kernel. This should
      be an odd number. If it is an even number, the actual kernel size will
      be size + 1.
    sigma: Sigma value for gaussian operator.
    padding: Padding to use for the convolution. Typically 'SAME' or 'VALID'.

  Returns:
    A Tensor representing the blurred image.
  """
  radius = tf.cast(kernel_size / 2, dtype=tf.int32)
  kernel_size = radius * 2 + 1
  x = tf.cast(tf.range(-radius, radius + 1), dtype=tf.float32)
  blur_filter = tf.exp(-tf.pow(x, 2.0) /
                       (2.0 * tf.pow(tf.cast(sigma, dtype=tf.float32), 2.0)))
  blur_filter /= tf.reduce_sum(blur_filter)
  # One vertical and one horizontal filter.
  blur_v = tf.reshape(blur_filter, [kernel_size, 1, 1, 1])
  blur_h = tf.reshape(blur_filter, [1, kernel_size, 1, 1])
  num_channels = tf.shape(image)[-1]
  blur_h = tf.tile(blur_h, [1, 1, num_channels, 1])
  blur_v = tf.tile(blur_v, [1, 1, num_channels, 1])
  expand_batch_dim = image.shape.ndims == 3
  if expand_batch_dim:
    # Tensorflow requires batched input to convolutions, which we can fake
    # with an extra dimension.
    image = tf.expand_dims(image, axis=0)
  blurred = tf.nn.depthwise_conv2d(
      image, blur_h, strides=[1, 1, 1, 1], padding=padding)
  blurred = tf.nn.depthwise_conv2d(
      blurred, blur_v, strides=[1, 1, 1, 1], padding=padding)
  if expand_batch_dim:
    blurred = tf.squeeze(blurred, axis=0)
  return blurred


def random_blur(image, height, width, p=1.0):
  """Randomly blurs an image.

  Args:
    image: `Tensor` representing an image of arbitrary size.
    height: Height of output image.
    width: Width of output image.
    p: Probability of applying this transformation.

  Returns:
    A preprocessed image `Tensor`.
  """
  del width

  def _transform(image):
    sigma = tf.random.uniform([], 0.1, 2.0, dtype=tf.float32)
    return gaussian_blur(
        image, kernel_size=height // 10, sigma=sigma, padding='SAME')

  return random_apply(_transform, p=p, x=image)


def random_blur_3d(frames, height, width, blur_probability=0.5):
  """Applies efficient batch data transformations.

  Args:
    frames: `Tensor` of shape [timesteps, height, width, 3].
    height: The height of the image.
    width: The width of the image.
    blur_probability: The probability to apply the blur operator.

  Returns:
    Preprocessed feature list.
  """

  def generate_selector(p, bsz):
    shape = [bsz, 1, 1, 1]
    selector = tf.cast(
        tf.less(tf.random.uniform(shape, 0, 1, dtype=tf.float32), p),
        tf.float32)
    return selector

  frames_new = random_blur(frames, height, width, p=1.)
  selector = generate_selector(blur_probability, 1)
  frames = frames_new * selector + frames * (1 - selector)
  frames = tf.clip_by_value(frames, 0., 1.)

  return frames


def _sample_or_pad_sequence_indices(sequence: tf.Tensor, num_steps: int,
                                    stride: int,
                                    offset: tf.Tensor) -> tf.Tensor:
  """Returns indices to take for sampling or padding sequences to fixed size."""
  sequence_length = tf.shape(sequence)[0]
  sel_idx = tf.range(sequence_length)
  # Repeats the sequence until num_steps are available in total.
  max_length = num_steps * stride + offset
  num_repeats = tf.math.floordiv(max_length + sequence_length - 1,
                                 sequence_length)
  sel_idx = tf.tile(sel_idx, [num_repeats])
  steps = tf.range(offset, offset + num_steps * stride, stride)
  return tf.gather(sel_idx, steps)


def sample_ssl_sequence(sequence: tf.Tensor,
                        num_steps: int,
                        random: bool,
                        stride: int = 1,
                        num_windows: Optional[int] = 2) -> tf.Tensor:
  """Samples two segments of size num_steps randomly from a given sequence.

  Currently it only supports images, and is specifically designed for video
  self-supervised learning.

  Args:
    sequence: Any tensor where the first dimension is timesteps.
    num_steps: Number of steps (e.g. frames) to take.
    random: A boolean indicating whether to randomly sample the single window.
      If True, the offset is randomized. Only True is supported.
    stride: Distance to sample between timesteps.
    num_windows: Number of sequences sampled.

  Returns:
    A single Tensor with first dimension num_steps with the sampled segment.
  """
  sequence_length = tf.shape(sequence)[0]
  sequence_length = tf.cast(sequence_length, tf.float32)

  if random:
    max_offset = tf.cond(
        tf.greater(sequence_length, (num_steps - 1) * stride),
        lambda: sequence_length - (num_steps - 1) * stride,
        lambda: sequence_length)
    max_offset = tf.cast(max_offset, dtype=tf.float32)

    def cdf(k, power=1.0):
      """Cumulative distribution function for x^power."""
      p = -tf.math.pow(k, power + 1) / (
          power * tf.math.pow(max_offset, power + 1)) + k * (power + 1) / (
              power * max_offset)
      return p

    u = tf.random.uniform(())
    k_low = tf.constant(0, dtype=tf.float32)
    k_up = max_offset
    k = tf.math.floordiv(max_offset, 2.0)
    c = lambda k_low, k_up, k: tf.greater(tf.math.abs(k_up - k_low), 1.0)
    # pylint:disable=g-long-lambda
    b = lambda k_low, k_up, k: tf.cond(
        tf.greater(cdf(k), u),
        lambda: [k_low, k, tf.math.floordiv(k + k_low, 2.0)],
        lambda: [k, k_up, tf.math.floordiv(k_up + k, 2.0)])
    _, _, k = tf.while_loop(c, b, [k_low, k_up, k])
    delta = tf.cast(k, tf.int32)

    max_offset = tf.cast(max_offset, tf.int32)
    sequence_length = tf.cast(sequence_length, tf.int32)
    choice_1 = tf.cond(
        tf.equal(max_offset, sequence_length),
        lambda: tf.random.uniform(
            (), maxval=tf.cast(max_offset, dtype=tf.int32), dtype=tf.int32),
        lambda: tf.random.uniform(
            (),
            maxval=tf.cast(max_offset - delta, dtype=tf.int32),
            dtype=tf.int32))
    choice_2 = tf.cond(
        tf.equal(max_offset, sequence_length),
        lambda: tf.random.uniform(
            (), maxval=tf.cast(max_offset, dtype=tf.int32), dtype=tf.int32),
        lambda: choice_1 + delta)
    # pylint:enable=g-long-lambda
    shuffle_choice = tf.random.shuffle((choice_1, choice_2))
    offset_1 = shuffle_choice[0]
    offset_2 = shuffle_choice[1]
  else:
    raise NotImplementedError

  indices_1 = _sample_or_pad_sequence_indices(
      sequence=sequence, num_steps=num_steps, stride=stride, offset=offset_1)
  indices_2 = _sample_or_pad_sequence_indices(
      sequence=sequence, num_steps=num_steps, stride=stride, offset=offset_2)

  indices = tf.concat([indices_1, indices_2], axis=0)
  indices.set_shape((num_windows * num_steps,))
  output = tf.gather(sequence, indices)

  return output
```
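The point of `random_color_jitter_3d` is that one augmentation draw covers the whole clip. A quick sketch of that property (assuming this module is imported as `video_ssl_preprocess_ops`): identical frames must stay identical after jittering, because the clip is flattened into a single image before the transform is sampled.

```python
import tensorflow as tf

frame = tf.random.uniform((32, 32, 3))      # one frame with values in [0, 1]
clip = tf.stack([frame, frame])             # a 2-frame clip of equal frames

jittered = video_ssl_preprocess_ops.random_color_jitter_3d(clip)

# Both frames received exactly the same jitter parameters.
tf.debugging.assert_near(jittered[0], jittered[1])
```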
official/vision/beta/projects/video_ssl/ops/video_ssl_preprocess_ops_test.py (new file, mode 100644)
```python
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf

from official.vision.beta.ops import preprocess_ops_3d
from official.vision.beta.projects.video_ssl.ops import video_ssl_preprocess_ops


class VideoSslPreprocessOpsTest(tf.test.TestCase):

  def setUp(self):
    super().setUp()
    self._raw_frames = tf.random.uniform((250, 256, 256, 3), minval=0,
                                         maxval=255, dtype=tf.dtypes.int32)
    self._sampled_frames = self._raw_frames[:16]
    self._frames = preprocess_ops_3d.normalize_image(
        self._sampled_frames, False, tf.float32)

  def test_sample_ssl_sequence(self):
    sampled_seq = video_ssl_preprocess_ops.sample_ssl_sequence(
        self._raw_frames, 16, True, 2)
    self.assertAllEqual(sampled_seq.shape, (32, 256, 256, 3))

  def test_random_color_jitter_3d(self):
    jittered_clip = video_ssl_preprocess_ops.random_color_jitter_3d(
        self._frames)
    self.assertAllEqual(jittered_clip.shape, (16, 256, 256, 3))

  def test_random_blur_3d(self):
    blurred_clip = video_ssl_preprocess_ops.random_blur_3d(
        self._frames, 256, 256)
    self.assertAllEqual(blurred_clip.shape, (16, 256, 256, 3))


if __name__ == '__main__':
  tf.test.main()
```
official/staging/training/__init__.py → official/vision/beta/projects/video_ssl/tasks/__init__.py (renamed)
```diff
@@ -12,3 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+"""Tasks package definition."""
+from official.vision.beta.projects.video_ssl.tasks import linear_eval
+from official.vision.beta.projects.video_ssl.tasks import pretrain
+
```
official/vision/beta/projects/video_ssl/tasks/linear_eval.py (new file, mode 100644)
```python
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Video ssl linear evaluation task definition."""
from typing import Any, Optional, List, Tuple

from absl import logging
import tensorflow as tf

# pylint: disable=unused-import
from official.core import task_factory
from official.vision.beta.projects.video_ssl.configs import video_ssl as exp_cfg
from official.vision.beta.projects.video_ssl.modeling import video_ssl_model
from official.vision.beta.tasks import video_classification


@task_factory.register_task_cls(exp_cfg.VideoSSLEvalTask)
class VideoSSLEvalTask(video_classification.VideoClassificationTask):
  """A task for video ssl linear evaluation."""

  def initialize(self, model: tf.keras.Model):
    """Loading pretrained checkpoint."""
    if not self.task_config.init_checkpoint:
      return

    ckpt_dir_or_file = self.task_config.init_checkpoint
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

    # Restoring checkpoint.
    if self.task_config.init_checkpoint_modules == 'backbone':
      ckpt = tf.train.Checkpoint(backbone=model.backbone)
      ckpt.read(ckpt_dir_or_file)
    else:
      raise NotImplementedError

    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)

  def train_step(self,
                 inputs: Tuple[Any, Any],
                 model: tf.keras.Model,
                 optimizer: tf.keras.optimizers.Optimizer,
                 metrics: Optional[List[Any]] = None):
    """Does forward and backward.

    Args:
      inputs: a dictionary of input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    model.backbone.trainable = False
    logging.info('Setting the backbone to non-trainable.')
    return super(video_classification.VideoClassificationTask,
                 self).train_step(inputs, model, optimizer, metrics)
```
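The linear-evaluation trick above relies on Keras `trainable` propagation. A minimal sketch (toy layers, not the real backbone): once `backbone.trainable = False`, its weights drop out of `model.trainable_variables`, so the gradient step in the parent `train_step` only touches the classifier head.

```python
import tensorflow as tf

backbone = tf.keras.Sequential([tf.keras.layers.Dense(8, input_shape=(4,))])
head = tf.keras.layers.Dense(2)
model = tf.keras.Sequential([backbone, head])
model(tf.zeros([1, 4]))  # build all variables

backbone.trainable = False
# Only the head's kernel and bias remain trainable.
print([v.name for v in model.trainable_variables])
```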
official/vision/beta/projects/video_ssl/tasks/pretrain.py (new file, mode 100644)
```python
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Video ssl pretrain task definition."""
from absl import logging
import tensorflow as tf

# pylint: disable=unused-import
from official.core import input_reader
from official.core import task_factory
from official.vision.beta.modeling import factory_3d
from official.vision.beta.projects.video_ssl.configs import video_ssl as exp_cfg
from official.vision.beta.projects.video_ssl.dataloaders import video_ssl_input
from official.vision.beta.projects.video_ssl.losses import losses
from official.vision.beta.projects.video_ssl.modeling import video_ssl_model
from official.vision.beta.tasks import video_classification


@task_factory.register_task_cls(exp_cfg.VideoSSLPretrainTask)
class VideoSSLPretrainTask(video_classification.VideoClassificationTask):
  """A task for video ssl pretraining."""

  def build_model(self):
    """Builds video ssl pretraining model."""
    common_input_shape = [
        d1 if d1 == d2 else None
        for d1, d2 in zip(self.task_config.train_data.feature_shape,
                          self.task_config.validation_data.feature_shape)
    ]
    input_specs = tf.keras.layers.InputSpec(shape=[None] + common_input_shape)
    logging.info('Build model input %r', common_input_shape)

    model = factory_3d.build_model(
        self.task_config.model.model_type,
        input_specs=input_specs,
        model_config=self.task_config.model,
        num_classes=self.task_config.train_data.num_classes)
    return model

  def _get_decoder_fn(self, params):
    decoder = video_ssl_input.Decoder()
    return decoder.decode

  def build_inputs(self, params: exp_cfg.DataConfig, input_context=None):
    """Builds classification input."""
    parser = video_ssl_input.Parser(input_params=params)
    postprocess_fn = video_ssl_input.PostBatchProcessor(params)

    reader = input_reader.InputReader(
        params,
        dataset_fn=self._get_dataset_fn(params),
        decoder_fn=self._get_decoder_fn(params),
        parser_fn=parser.parse_fn(params.is_training),
        postprocess_fn=postprocess_fn)

    dataset = reader.read(input_context=input_context)

    return dataset

  def build_losses(self, model_outputs, num_replicas, model):
    """Builds the contrastive loss and its metrics.

    Args:
      model_outputs: Output logits of the model.
      num_replicas: distributed replica number.
      model: keras model for calculating weight decay.

    Returns:
      A tuple of the losses dictionary and the contrastive metrics dictionary.
    """
    all_losses = {}
    contrastive_metrics = {}
    losses_config = self.task_config.losses
    total_loss = None
    contrastive_loss_dict = losses.contrastive_loss(
        model_outputs, num_replicas, losses_config.normalize_hidden,
        losses_config.temperature, model,
        self.task_config.losses.l2_weight_decay)
    total_loss = contrastive_loss_dict['total_loss']
    all_losses.update({'total_loss': total_loss})
    all_losses[self.loss] = total_loss
    contrastive_metrics.update({
        'contrast_acc': contrastive_loss_dict['contrast_acc'],
        'contrast_entropy': contrastive_loss_dict['contrast_entropy'],
        'reg_loss': contrastive_loss_dict['reg_loss']
    })
    return all_losses, contrastive_metrics

  def build_metrics(self, training=True):
    """Gets streaming metrics for training/validation."""
    metrics = [
        tf.keras.metrics.Mean(name='contrast_acc'),
        tf.keras.metrics.Mean(name='contrast_entropy'),
        tf.keras.metrics.Mean(name='reg_loss')
    ]
    return metrics

  def process_metrics(self, metrics, contrastive_metrics):
    """Processes and updates metrics."""
    contrastive_metric_values = contrastive_metrics.values()
    for metric, contrastive_metric_value in zip(metrics,
                                                contrastive_metric_values):
      metric.update_state(contrastive_metric_value)

  def train_step(self, inputs, model, optimizer, metrics=None):
    """Does forward and backward.

    Args:
      inputs: a dictionary of input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, _ = inputs
    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    with tf.GradientTape() as tape:
      if self.task_config.train_data.output_audio:
        outputs = model(features, training=True)
      else:
        outputs = model(features['image'], training=True)
      # Casting output layer as float32 is necessary when mixed_precision is
      # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
      outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
                                      outputs)

      all_losses, contrastive_metrics = self.build_losses(
          model_outputs=outputs, num_replicas=num_replicas, model=model)
      loss = all_losses[self.loss]
      scaled_loss = loss

      # For mixed_precision policy, when LossScaleOptimizer is used, loss is
      # scaled for numerical stability.
      if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    # Scales back gradient before apply_gradients when LossScaleOptimizer is
    # used.
    if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    logs = all_losses
    if metrics:
      self.process_metrics(metrics, contrastive_metrics)
      logs.update({m.name: m.result() for m in metrics})

    return logs

  def validation_step(self, inputs, model, metrics=None):
    """Validation step.

    Args:
      inputs: a dictionary of input tensors.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    raise NotImplementedError

  def inference_step(self, features, model):
    """Performs the forward step."""
    raise NotImplementedError
```
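The scale/unscale pairing in `train_step` above is the standard `LossScaleOptimizer` pattern. A self-contained sketch of just that dance:

```python
import tensorflow as tf

optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.SGD(learning_rate=0.1))
var = tf.Variable(1.0)

with tf.GradientTape() as tape:
  loss = var * var
  # Scale the loss up so float16 gradients do not underflow.
  scaled_loss = optimizer.get_scaled_loss(loss)

grads = tape.gradient(scaled_loss, [var])
# Undo the scaling before the update.
grads = optimizer.get_unscaled_gradients(grads)
optimizer.apply_gradients(zip(grads, [var]))
```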
official/vision/beta/projects/video_ssl/tasks/pretrain_test.py (new file, mode 100644)
```python
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
import functools
import os
import random

import orbit
import tensorflow as tf

# pylint: disable=unused-import
from official.core import exp_factory
from official.core import task_factory
from official.modeling import optimization
from official.vision import beta
from official.vision.beta.dataloaders import tfexample_utils
from official.vision.beta.projects.video_ssl.tasks import pretrain


class VideoClassificationTaskTest(tf.test.TestCase):

  def setUp(self):
    super(VideoClassificationTaskTest, self).setUp()
    data_dir = os.path.join(self.get_temp_dir(), 'data')
    tf.io.gfile.makedirs(data_dir)
    self._data_path = os.path.join(data_dir, 'data.tfrecord')
    # pylint: disable=g-complex-comprehension
    examples = [
        tfexample_utils.make_video_test_example(
            image_shape=(36, 36, 3),
            audio_shape=(20, 128),
            label=random.randint(0, 100)) for _ in range(2)
    ]
    # pylint: enable=g-complex-comprehension
    tfexample_utils.dump_to_tfrecord(self._data_path, tf_examples=examples)

  def test_task(self):
    config = exp_factory.get_exp_config('video_ssl_pretrain_kinetics600')
    config.task.train_data.global_batch_size = 2
    config.task.train_data.input_path = self._data_path

    task = pretrain.VideoSSLPretrainTask(config.task)
    model = task.build_model()
    metrics = task.build_metrics()
    strategy = tf.distribute.get_strategy()

    dataset = orbit.utils.make_distributed_dataset(
        strategy, functools.partial(task.build_inputs),
        config.task.train_data)

    iterator = iter(dataset)
    opt_factory = optimization.OptimizerFactory(config.trainer.optimizer_config)
    optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())

    logs = task.train_step(next(iterator), model, optimizer, metrics=metrics)
    self.assertIn('total_loss', logs)
    self.assertIn('reg_loss', logs)
    self.assertIn('contrast_acc', logs)
    self.assertIn('contrast_entropy', logs)

  def test_task_factory(self):
    config = exp_factory.get_exp_config('video_ssl_pretrain_kinetics600')
    task = task_factory.get_task(config.task)
    self.assertIs(type(task), pretrain.VideoSSLPretrainTask)


if __name__ == '__main__':
  tf.test.main()
```
official/vision/beta/projects/video_ssl/train.py (new file, mode 100644)
```python
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training driver."""
from absl import app
from absl import flags
import gin

# pylint: disable=unused-import
from official.common import registry_imports
from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
from official.vision.beta.projects.video_ssl.modeling import video_ssl_model
from official.vision.beta.projects.video_ssl.tasks import linear_eval
from official.vision.beta.projects.video_ssl.tasks import pretrain
# pylint: enable=unused-import

FLAGS = flags.FLAGS


def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  params = train_utils.parse_configuration(FLAGS)
  model_dir = FLAGS.model_dir
  if 'train' in FLAGS.mode:
    # Pure eval modes do not output yaml files. Otherwise continuous eval job
    # may race against the train job for writing the same file.
    train_utils.serialize_config(params, model_dir)

  if 'train_and_eval' in FLAGS.mode:
    assert (params.task.train_data.feature_shape ==
            params.task.validation_data.feature_shape), (
                f'train {params.task.train_data.feature_shape} != validate '
                f'{params.task.validation_data.feature_shape}')

  # Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can have significant impact on model speeds by utilizing float16 in case
  # of GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only
  # when dtype is float16.
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu)
  with distribution_strategy.scope():
    task = task_factory.get_task(params.task, logging_dir=model_dir)

  train_lib.run_experiment(
      distribution_strategy=distribution_strategy,
      task=task,
      mode=FLAGS.mode,
      params=params,
      model_dir=model_dir)

  train_utils.save_gin_config(FLAGS.mode, model_dir)


if __name__ == '__main__':
  tfm_flags.define_flags()
  app.run(main)
```
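A hypothetical launch of this driver (the experiment name comes from pretrain_test.py above; `--mode`, `--model_dir`, and friends are the usual Model Garden flags defined by `tfm_flags.define_flags()`, and the paths are placeholders):

```shell
python3 -m official.vision.beta.projects.video_ssl.train \
  --experiment=video_ssl_pretrain_kinetics600 \
  --mode=train \
  --model_dir=/tmp/video_ssl_pretrain
```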
official/vision/beta/projects/volumetric_models/tasks/semantic_segmentation_3d.py (modified)
```diff
@@ -79,8 +79,8 @@ class SemanticSegmentation3DTask(base_task.Task):
     # Restoring checkpoint.
     if 'all' in self.task_config.init_checkpoint_modules:
       ckpt = tf.train.Checkpoint(**model.checkpoint_items)
-      status = ckpt.restore(ckpt_dir_or_file)
-      status.assert_consumed()
+      status = ckpt.read(ckpt_dir_or_file)
+      status.expect_partial().assert_existing_objects_matched()
     else:
       ckpt_items = {}
       if 'backbone' in self.task_config.init_checkpoint_modules:
@@ -89,7 +89,7 @@ class SemanticSegmentation3DTask(base_task.Task):
         ckpt_items.update(decoder=model.decoder)
 
       ckpt = tf.train.Checkpoint(**ckpt_items)
-      status = ckpt.restore(ckpt_dir_or_file)
+      status = ckpt.read(ckpt_dir_or_file)
       status.expect_partial().assert_existing_objects_matched()
 
       logging.info('Finished loading pretrained checkpoint from %s',
```
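This hunk (like the matching ones in the task files below) swaps `Checkpoint.restore` for `Checkpoint.read` and relaxes `assert_consumed` to `expect_partial().assert_existing_objects_matched()`: `read` skips the `save_counter` bookkeeping that `save`/`restore` add, and the relaxed assertion tolerates extra objects in the checkpoint while still verifying that every object built in the current model was matched. A minimal sketch of the pairing (the path is a placeholder):

```python
import tensorflow as tf

src = tf.train.Checkpoint(v=tf.Variable(3.0))
path = src.write('/tmp/demo_ckpt')  # write/read skip save-counter bookkeeping

dst = tf.train.Checkpoint(v=tf.Variable(0.0))
status = dst.read(path)
status.expect_partial().assert_existing_objects_matched()
assert dst.v.numpy() == 3.0
```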
official/vision/beta/serving/export_base.py (modified)
```diff
@@ -103,6 +103,10 @@ class ExportModule(export_base.ExportModule, metaclass=abc.ABCMeta):
       self, inputs: tf.Tensor) -> Mapping[str, tf.Tensor]:
     return self.serve(inputs)
 
+  @tf.function
+  def inference_for_tflite(self, inputs: tf.Tensor) -> Mapping[str, tf.Tensor]:
+    return self.serve(inputs)
+
   @tf.function
   def inference_from_image_bytes(self, inputs: tf.Tensor):
     with tf.device('cpu:0'):
@@ -174,6 +178,13 @@ class ExportModule(export_base.ExportModule, metaclass=abc.ABCMeta):
       signatures[
           def_name] = self.inference_from_tf_example.get_concrete_function(
               input_signature)
+    elif key == 'tflite':
+      input_signature = tf.TensorSpec(
+          shape=[self._batch_size] + self._input_image_size +
+          [self._num_channels],
+          dtype=tf.float32)
+      signatures[def_name] = self.inference_for_tflite.get_concrete_function(
+          input_signature)
     else:
       raise ValueError('Unrecognized `input_type`')
     return signatures
```
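A sketch of exporting a serving module with the new 'tflite' signature, mirroring the calls in export_tflite_lib_test.py further down (the experiment name is taken from that test; the output path is a placeholder):

```python
import tensorflow as tf

from official.core import exp_factory
from official.vision.beta.serving import image_classification as serving_cls

params = exp_factory.get_exp_config('mobilenet_imagenet')
module = serving_cls.ClassificationModule(
    params=params, batch_size=1, input_image_size=[224, 224])
signatures = module.get_inference_signatures({'tflite': 'serving_default'})
tf.saved_model.save(module, '/tmp/saved_model', signatures=signatures)
```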
official/vision/beta/serving/export_tflite.py (new file, mode 100644)
```python
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Binary to convert a saved model to tflite model."""
from absl import app
from absl import flags
from absl import logging

import tensorflow as tf

from official.common import registry_imports  # pylint: disable=unused-import
from official.core import exp_factory
from official.modeling import hyperparams
from official.vision.beta.serving import export_tflite_lib

FLAGS = flags.FLAGS

flags.DEFINE_string(
    'experiment', None, 'experiment type, e.g. retinanet_resnetfpn_coco',
    required=True)
flags.DEFINE_multi_string(
    'config_file',
    default='',
    help='YAML/JSON files which specifies overrides. The override order '
    'follows the order of args. Note that each file '
    'can be used as an override template to override the default parameters '
    'specified in Python. If the same parameter is specified in both '
    '`--config_file` and `--params_override`, `config_file` will be used '
    'first, followed by params_override.')
flags.DEFINE_string(
    'params_override', '',
    'The JSON/YAML file or string which specifies the parameter to be '
    'overridden on top of `config_file` template.')
flags.DEFINE_string(
    'saved_model_dir', None, 'The directory to the saved model.',
    required=True)
flags.DEFINE_string(
    'tflite_path', None, 'The path to the output tflite model.',
    required=True)
flags.DEFINE_string(
    'quant_type',
    default=None,
    help='Post training quantization type. Support `int8`, `int8_full`, '
    '`fp16`, and `default`. See '
    'https://www.tensorflow.org/lite/performance/post_training_quantization '
    'for more details.')
flags.DEFINE_integer('calibration_steps', 500,
                     'The number of calibration steps for integer model.')


def main(_) -> None:
  params = exp_factory.get_exp_config(FLAGS.experiment)
  if FLAGS.config_file is not None:
    for config_file in FLAGS.config_file:
      params = hyperparams.override_params_dict(
          params, config_file, is_strict=True)
  if FLAGS.params_override:
    params = hyperparams.override_params_dict(
        params, FLAGS.params_override, is_strict=True)

  params.validate()
  params.lock()

  logging.info('Converting SavedModel from %s to TFLite model...',
               FLAGS.saved_model_dir)
  tflite_model = export_tflite_lib.convert_tflite_model(
      saved_model_dir=FLAGS.saved_model_dir,
      quant_type=FLAGS.quant_type,
      params=params,
      calibration_steps=FLAGS.calibration_steps)

  with tf.io.gfile.GFile(FLAGS.tflite_path, 'wb') as fw:
    fw.write(tflite_model)

  logging.info('TFLite model converted and saved to %s.', FLAGS.tflite_path)


if __name__ == '__main__':
  app.run(main)
```
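A hypothetical invocation using the flags defined above (paths are placeholders; the experiment name appears in export_tflite_lib_test.py below):

```shell
python3 -m official.vision.beta.serving.export_tflite \
  --experiment=mobilenet_imagenet \
  --saved_model_dir=/tmp/saved_model \
  --tflite_path=/tmp/model.tflite \
  --quant_type=fp16
```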
official/vision/beta/serving/export_tflite_lib.py (new file, mode 100644)
```python
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library to facilitate TFLite model conversion."""
import functools
from typing import Iterator, List, Optional

from absl import logging
import tensorflow as tf

from official.core import config_definitions as cfg
from official.vision.beta import configs
from official.vision.beta.tasks import image_classification as img_cls_task


def create_representative_dataset(
    params: cfg.ExperimentConfig) -> tf.data.Dataset:
  """Creates a tf.data.Dataset to load images for representative dataset.

  Args:
    params: An ExperimentConfig.

  Returns:
    A tf.data.Dataset instance.

  Raises:
    ValueError: If the task is not supported.
  """
  if isinstance(params.task,
                configs.image_classification.ImageClassificationTask):
    task = img_cls_task.ImageClassificationTask(params.task)
  else:
    raise ValueError('Task {} not supported.'.format(type(params.task)))
  # Ensure batch size is 1 for TFLite model.
  params.task.train_data.global_batch_size = 1
  params.task.train_data.dtype = 'float32'
  logging.info('Task config: %s', params.task.as_dict())
  return task.build_inputs(params=params.task.train_data)


def representative_dataset(
    params: cfg.ExperimentConfig,
    calibration_steps: int = 2000) -> Iterator[List[tf.Tensor]]:
  """Creates representative dataset for input calibration.

  Args:
    params: An ExperimentConfig.
    calibration_steps: The steps to do calibration.

  Yields:
    An input image tensor.
  """
  dataset = create_representative_dataset(params=params)
  for image, _ in dataset.take(calibration_steps):
    # Skip images that do not have 3 channels.
    if image.shape[-1] != 3:
      continue
    yield [image]


def convert_tflite_model(saved_model_dir: str,
                         quant_type: Optional[str] = None,
                         params: Optional[cfg.ExperimentConfig] = None,
                         calibration_steps: Optional[int] = 2000) -> bytes:
  """Converts and returns a TFLite model.

  Args:
    saved_model_dir: The directory to the SavedModel.
    quant_type: The post training quantization (PTQ) method. It can be one of
      `default` (dynamic range), `fp16` (float16), `int8` (integer with float
      fallback), `int8_full` (integer only) and None (no quantization).
    params: An optional ExperimentConfig to load and preprocess input images
      to do calibration for integer quantization.
    calibration_steps: The steps to do calibration.

  Returns:
    A converted TFLite model with optional PTQ.

  Raises:
    ValueError: If `representative_dataset_path` is not present when integer
      quantization is requested.
  """
  converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
  if quant_type:
    if quant_type.startswith('int8'):
      converter.optimizations = [tf.lite.Optimize.DEFAULT]
      converter.representative_dataset = functools.partial(
          representative_dataset,
          params=params,
          calibration_steps=calibration_steps)
      if quant_type == 'int8_full':
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS_INT8
        ]
        converter.inference_input_type = tf.uint8  # or tf.int8
        converter.inference_output_type = tf.uint8  # or tf.int8
    elif quant_type == 'fp16':
      converter.optimizations = [tf.lite.Optimize.DEFAULT]
      converter.target_spec.supported_types = [tf.float16]
    elif quant_type == 'default':
      converter.optimizations = [tf.lite.Optimize.DEFAULT]

  return converter.convert()
```
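The returned bytes can be sanity-checked with the stock TFLite interpreter. A short sketch for a float model (for `int8_full` the input dtype would be uint8 instead):

```python
import numpy as np
import tensorflow as tf

# `tflite_model` is the bytes object returned by convert_tflite_model above.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

input_detail = interpreter.get_input_details()[0]
interpreter.set_tensor(input_detail['index'],
                       np.zeros(input_detail['shape'], dtype=np.float32))
interpreter.invoke()
output = interpreter.get_tensor(interpreter.get_output_details()[0]['index'])
```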
official/vision/beta/serving/export_tflite_lib_test.py (new file, mode 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for export_tflite_lib."""
import os

from absl.testing import parameterized
import tensorflow as tf

from tensorflow.python.distribute import combinations
from official.common import registry_imports  # pylint: disable=unused-import
from official.core import exp_factory
from official.vision.beta.dataloaders import tfexample_utils
from official.vision.beta.serving import export_tflite_lib
from official.vision.beta.serving import image_classification as image_classification_serving


class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self._test_tfrecord_file = os.path.join(self.get_temp_dir(),
                                            'test.tfrecord')
    self._create_test_tfrecord(num_samples=50)

  def _create_test_tfrecord(self, num_samples):
    tfexample_utils.dump_to_tfrecord(self._test_tfrecord_file, [
        tf.train.Example.FromString(
            tfexample_utils.create_classification_example(
                image_height=256, image_width=256))
        for _ in range(num_samples)
    ])

  def _export_from_module(self, module, input_type, saved_model_dir):
    signatures = module.get_inference_signatures(
        {input_type: 'serving_default'})
    tf.saved_model.save(module, saved_model_dir, signatures=signatures)

  @combinations.generate(
      combinations.combine(
          experiment=['mobilenet_imagenet'],
          quant_type=[None, 'default', 'fp16', 'int8'],
          input_image_size=[[224, 224]]))
  def test_export_tflite(self, experiment, quant_type, input_image_size):
    params = exp_factory.get_exp_config(experiment)
    params.task.validation_data.input_path = self._test_tfrecord_file
    params.task.train_data.input_path = self._test_tfrecord_file
    temp_dir = self.get_temp_dir()
    module = image_classification_serving.ClassificationModule(
        params=params,
        batch_size=1,
        input_image_size=input_image_size)
    self._export_from_module(
        module=module,
        input_type='tflite',
        saved_model_dir=os.path.join(temp_dir, 'saved_model'))

    tflite_model = export_tflite_lib.convert_tflite_model(
        saved_model_dir=os.path.join(temp_dir, 'saved_model'),
        quant_type=quant_type,
        params=params,
        calibration_steps=5)

    self.assertIsInstance(tflite_model, bytes)


if __name__ == '__main__':
  tf.test.main()
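A hedged sketch of how the returned bytes could be sanity-checked beyond `assertIsInstance`, using the standard TFLite interpreter (not part of this test; the input shape and dtype come from the model itself):

import numpy as np
import tensorflow as tf

# Load the converted FlatBuffer directly from memory and run one inference.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()[0]
dummy_input = np.zeros(input_details['shape'], dtype=input_details['dtype'])
interpreter.set_tensor(input_details['index'], dummy_input)
interpreter.invoke()
logits = interpreter.get_tensor(interpreter.get_output_details()[0]['index'])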
official/vision/beta/tasks/image_classification.py
View file @
1f8b5b27
...
@@ -63,11 +63,11 @@ class ImageClassificationTask(base_task.Task):
     # Restoring checkpoint.
     if self.task_config.init_checkpoint_modules == 'all':
       ckpt = tf.train.Checkpoint(**model.checkpoint_items)
-      status = ckpt.restore(ckpt_dir_or_file)
-      status.assert_consumed()
+      status = ckpt.read(ckpt_dir_or_file)
+      status.expect_partial().assert_existing_objects_matched()
     elif self.task_config.init_checkpoint_modules == 'backbone':
       ckpt = tf.train.Checkpoint(backbone=model.backbone)
-      status = ckpt.restore(ckpt_dir_or_file)
+      status = ckpt.read(ckpt_dir_or_file)
       status.expect_partial().assert_existing_objects_matched()
     else:
       raise ValueError(
...
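The change above swaps `Checkpoint.restore` for `Checkpoint.read` and relaxes `assert_consumed` to `expect_partial().assert_existing_objects_matched()`: `read` loads a checkpoint without expecting the extra bookkeeping (such as `save_counter`) that `Checkpoint.save` records, and the relaxed assertion only requires that every object already built in the Python object graph found a matching value, tolerating unused values in the checkpoint. A standalone sketch of the new semantics (the layer and path are illustrative):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
ckpt_path = tf.train.Checkpoint(model=model).write('/tmp/ckpt_demo')

status = tf.train.Checkpoint(model=model).read(ckpt_path)
# Passes as long as every object built above matched a checkpoint value;
# values in the checkpoint with no corresponding object are tolerated.
status.expect_partial().assert_existing_objects_matched()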
official/vision/beta/tasks/maskrcnn.py
View file @
1f8b5b27
...
@@ -96,15 +96,18 @@ class MaskRCNNTask(base_task.Task):
     # Restoring checkpoint.
     if self.task_config.init_checkpoint_modules == 'all':
       ckpt = tf.train.Checkpoint(**model.checkpoint_items)
-      status = ckpt.restore(ckpt_dir_or_file)
-      status.assert_consumed()
-    elif self.task_config.init_checkpoint_modules == 'backbone':
-      ckpt = tf.train.Checkpoint(backbone=model.backbone)
-      status = ckpt.restore(ckpt_dir_or_file)
+      status = ckpt.read(ckpt_dir_or_file)
       status.expect_partial().assert_existing_objects_matched()
     else:
-      raise ValueError(
-          "Only 'all' or 'backbone' can be used to initialize the model.")
+      ckpt_items = {}
+      if 'backbone' in self.task_config.init_checkpoint_modules:
+        ckpt_items.update(backbone=model.backbone)
+      if 'decoder' in self.task_config.init_checkpoint_modules:
+        ckpt_items.update(decoder=model.decoder)
+
+      ckpt = tf.train.Checkpoint(**ckpt_items)
+      status = ckpt.read(ckpt_dir_or_file)
+      status.expect_partial().assert_existing_objects_matched()
 
     logging.info('Finished loading pretrained checkpoint from %s',
                  ckpt_dir_or_file)
...
official/vision/beta/tasks/retinanet.py
View file @
1f8b5b27
...
@@ -71,15 +71,18 @@ class RetinaNetTask(base_task.Task):
     # Restoring checkpoint.
     if self.task_config.init_checkpoint_modules == 'all':
       ckpt = tf.train.Checkpoint(**model.checkpoint_items)
-      status = ckpt.restore(ckpt_dir_or_file)
-      status.assert_consumed()
-    elif self.task_config.init_checkpoint_modules == 'backbone':
-      ckpt = tf.train.Checkpoint(backbone=model.backbone)
-      status = ckpt.restore(ckpt_dir_or_file)
+      status = ckpt.read(ckpt_dir_or_file)
       status.expect_partial().assert_existing_objects_matched()
     else:
-      raise ValueError(
-          "Only 'all' or 'backbone' can be used to initialize the model.")
+      ckpt_items = {}
+      if 'backbone' in self.task_config.init_checkpoint_modules:
+        ckpt_items.update(backbone=model.backbone)
+      if 'decoder' in self.task_config.init_checkpoint_modules:
+        ckpt_items.update(decoder=model.decoder)
+
+      ckpt = tf.train.Checkpoint(**ckpt_items)
+      status = ckpt.read(ckpt_dir_or_file)
+      status.expect_partial().assert_existing_objects_matched()
 
     logging.info('Finished loading pretrained checkpoint from %s',
                  ckpt_dir_or_file)
...
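The new `else` branch in both detection tasks replaces the single `'backbone'` option with an arbitrary subset of module names, so for example `init_checkpoint_modules=['backbone', 'decoder']` restores both modules from one pretraining checkpoint. A generalized sketch of the pattern (`restore_modules`, the path, and the module names are hypothetical, for illustration only):

import tensorflow as tf

def restore_modules(model, ckpt_dir_or_file, module_names):
  """Restores only the requested sub-modules from a checkpoint."""
  ckpt_items = {name: getattr(model, name) for name in module_names}
  status = tf.train.Checkpoint(**ckpt_items).read(ckpt_dir_or_file)
  # Require a match for everything requested; ignore the rest.
  status.expect_partial().assert_existing_objects_matched()

# e.g. restore_modules(model, '/tmp/pretrain_ckpt', ['backbone', 'decoder'])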
official/vision/beta/tasks/semantic_segmentation.py
View file @
1f8b5b27
...
@@ -63,8 +63,8 @@ class SemanticSegmentationTask(base_task.Task):
     # Restoring checkpoint.
     if 'all' in self.task_config.init_checkpoint_modules:
       ckpt = tf.train.Checkpoint(**model.checkpoint_items)
-      status = ckpt.restore(ckpt_dir_or_file)
-      status.assert_consumed()
+      status = ckpt.read(ckpt_dir_or_file)
+      status.expect_partial().assert_existing_objects_matched()
     else:
       ckpt_items = {}
       if 'backbone' in self.task_config.init_checkpoint_modules:
...
@@ -73,7 +73,7 @@ class SemanticSegmentationTask(base_task.Task):
         ckpt_items.update(decoder=model.decoder)

       ckpt = tf.train.Checkpoint(**ckpt_items)
-      status = ckpt.restore(ckpt_dir_or_file)
+      status = ckpt.read(ckpt_dir_or_file)
       status.expect_partial().assert_existing_objects_matched()
     logging.info('Finished loading pretrained checkpoint from %s',
...
official/vision/beta/tasks/video_classification.py
View file @
1f8b5b27
...
@@ -86,11 +86,11 @@ class VideoClassificationTask(base_task.Task):
     # Restoring checkpoint.
     if self.task_config.init_checkpoint_modules == 'all':
       ckpt = tf.train.Checkpoint(**model.checkpoint_items)
-      status = ckpt.restore(ckpt_dir_or_file)
-      status.assert_consumed()
+      status = ckpt.read(ckpt_dir_or_file)
+      status.expect_partial().assert_existing_objects_matched()
     elif self.task_config.init_checkpoint_modules == 'backbone':
       ckpt = tf.train.Checkpoint(backbone=model.backbone)
-      status = ckpt.restore(ckpt_dir_or_file)
+      status = ckpt.read(ckpt_dir_or_file)
       status.expect_partial().assert_existing_objects_matched()
     else:
       raise ValueError(
...
official/vision/detection/ops/postprocess_ops.py
View file @
1f8b5b27
...
@@ -270,11 +270,6 @@ def _generate_detections_batched(boxes, scores, max_total_size,
       `valid_detections` boxes are valid detections.
   """
   with tf.name_scope('generate_detections'):
-    # TODO(tsungyi): Removes normalization/denomalization once the
-    # tf.image.combined_non_max_suppression is coordinate system agnostic.
-    # Normalizes maximum box cooridinates to 1.
-    normalizer = tf.reduce_max(boxes)
-    boxes /= normalizer
     (nmsed_boxes, nmsed_scores, nmsed_classes,
      valid_detections) = tf.image.combined_non_max_suppression(
         boxes,
...
@@ -284,9 +279,7 @@ def _generate_detections_batched(boxes, scores, max_total_size,
         iou_threshold=nms_iou_threshold,
         score_threshold=score_threshold,
         pad_per_class=False,
-    )
-    # De-normalizes box cooridinates.
-    nmsed_boxes *= normalizer
+        clip_boxes=False)
     nmsed_classes = tf.cast(nmsed_classes, tf.int32)
     return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
...
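The deleted normalize/denormalize round trip existed because `tf.image.combined_non_max_suppression` clips output boxes to `[0, 1]` by default; passing `clip_boxes=False` makes the op agnostic to the coordinate system, so the workaround in the old TODO is no longer needed. A self-contained sketch with illustrative shapes and values:

import tensorflow as tf

# Boxes in absolute pixel coordinates: [batch, num_boxes, q, 4] with q=1;
# scores: [batch, num_boxes, num_classes]. Values are random placeholders.
boxes = tf.random.uniform([1, 10, 1, 4], maxval=640.0)
scores = tf.random.uniform([1, 10, 3])
nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = (
    tf.image.combined_non_max_suppression(
        boxes,
        scores,
        max_output_size_per_class=5,
        max_total_size=10,
        iou_threshold=0.5,
        score_threshold=0.05,
        clip_boxes=False))  # boxes stay in pixel coordinates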