Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
57557684
Unverified
Commit
57557684
authored
Apr 29, 2022
by
Srihari Humbarwadi
Committed by
GitHub
Apr 29, 2022
Browse files
Merge branch 'tensorflow:master' into panoptic-deeplab
parents
c4ce3a9e
2f9266ac
Changes
31
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
846 additions
and
10 deletions
+846
-10
official/projects/deepmac_maskrcnn/README.md
official/projects/deepmac_maskrcnn/README.md
+1
-1
official/projects/movinet/configs/movinet.py
official/projects/movinet/configs/movinet.py
+1
-0
official/projects/movinet/modeling/movinet.py
official/projects/movinet/modeling/movinet.py
+7
-1
official/projects/movinet/modeling/movinet_layers.py
official/projects/movinet/modeling/movinet_layers.py
+22
-5
official/projects/qat/vision/configs/common.py
official/projects/qat/vision/configs/common.py
+3
-0
official/projects/qat/vision/configs/experiments/retinanet/coco_spinenet49_mobile_qat_tpu.yaml
...experiments/retinanet/coco_spinenet49_mobile_qat_tpu.yaml
+66
-0
official/projects/qat/vision/configs/experiments/retinanet/coco_spinenet49_mobile_qat_tpu_e2e.yaml
...riments/retinanet/coco_spinenet49_mobile_qat_tpu_e2e.yaml
+67
-0
official/projects/qat/vision/modeling/factory.py
official/projects/qat/vision/modeling/factory.py
+10
-1
official/projects/qat/vision/modeling/heads/__init__.py
official/projects/qat/vision/modeling/heads/__init__.py
+18
-0
official/projects/qat/vision/modeling/heads/dense_prediction_heads.py
...jects/qat/vision/modeling/heads/dense_prediction_heads.py
+428
-0
official/projects/qat/vision/modeling/heads/dense_prediction_heads_test.py
.../qat/vision/modeling/heads/dense_prediction_heads_test.py
+92
-0
official/projects/video_ssl/dataloaders/video_ssl_input.py
official/projects/video_ssl/dataloaders/video_ssl_input.py
+9
-2
official/utils/docs/build_tfm_api_docs.py
official/utils/docs/build_tfm_api_docs.py
+44
-0
official/vision/beta/__init__.py
official/vision/beta/__init__.py
+14
-0
official/vision/beta/projects/yolo/modeling/__init__.py
official/vision/beta/projects/yolo/modeling/__init__.py
+14
-0
official/vision/beta/projects/yolo/modeling/backbones/__init__.py
.../vision/beta/projects/yolo/modeling/backbones/__init__.py
+14
-0
official/vision/beta/projects/yolo/modeling/layers/__init__.py
...ial/vision/beta/projects/yolo/modeling/layers/__init__.py
+14
-0
official/vision/beta/projects/yolo/serving/__init__.py
official/vision/beta/projects/yolo/serving/__init__.py
+14
-0
official/vision/configs/semantic_segmentation.py
official/vision/configs/semantic_segmentation.py
+7
-0
official/vision/configs/video_classification.py
official/vision/configs/video_classification.py
+1
-0
No files found.
official/projects/deepmac_maskrcnn/README.md
View file @
57557684
...
...
@@ -107,7 +107,7 @@ SpienNet-143 | Hourglass-52 | `deep_mask_head_rcnn_voc_spinenet143_hg52.yaml` |
*
[
DeepMAC model
](
https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/deepmac.md
)
in the Object Detection API code base.
*
Project website -
[
git.io/deepmac
](
https://g
it
.io/deepmac
)
*
Project website -
[
git.io/deepmac
](
https://g
oogle.github
.io/deepmac
/
)
## Citation
...
...
official/projects/movinet/configs/movinet.py
View file @
57557684
...
...
@@ -53,6 +53,7 @@ class Movinet(hyperparams.Config):
gating_activation
:
str
=
'sigmoid'
stochastic_depth_drop_rate
:
float
=
0.2
use_external_states
:
bool
=
False
average_pooling_type
:
str
=
'3d'
@
dataclasses
.
dataclass
...
...
official/projects/movinet/modeling/movinet.py
View file @
57557684
...
...
@@ -322,6 +322,7 @@ class Movinet(tf.keras.Model):
stochastic_depth_drop_rate
:
float
=
0.
,
use_external_states
:
bool
=
False
,
output_states
:
bool
=
True
,
average_pooling_type
:
str
=
'3d'
,
**
kwargs
):
"""MoViNet initialization function.
...
...
@@ -360,6 +361,8 @@ class Movinet(tf.keras.Model):
the model in streaming mode. Inputting the output states of the
previous input clip with the current input clip will utilize a stream
buffer for streaming video.
average_pooling_type: The average pooling type. Currently supporting
['3d', '2d', 'none'].
**kwargs: keyword arguments to be passed.
"""
block_specs
=
BLOCK_SPECS
[
model_id
]
...
...
@@ -393,6 +396,7 @@ class Movinet(tf.keras.Model):
self
.
_stochastic_depth_drop_rate
=
stochastic_depth_drop_rate
self
.
_use_external_states
=
use_external_states
self
.
_output_states
=
output_states
self
.
_average_pooling_type
=
average_pooling_type
if
self
.
_use_external_states
and
not
self
.
_causal
:
raise
ValueError
(
'External states should be used with causal mode.'
)
...
...
@@ -520,6 +524,7 @@ class Movinet(tf.keras.Model):
batch_norm_layer
=
self
.
_norm
,
batch_norm_momentum
=
self
.
_norm_momentum
,
batch_norm_epsilon
=
self
.
_norm_epsilon
,
average_pooling_type
=
self
.
_average_pooling_type
,
state_prefix
=
'state_head'
,
name
=
'head'
)
x
,
states
=
layer_obj
(
x
,
states
=
states
)
...
...
@@ -730,4 +735,5 @@ def build_movinet(
norm_epsilon
=
norm_activation_config
.
norm_epsilon
,
kernel_regularizer
=
l2_regularizer
,
stochastic_depth_drop_rate
=
backbone_cfg
.
stochastic_depth_drop_rate
,
use_external_states
=
backbone_cfg
.
use_external_states
)
use_external_states
=
backbone_cfg
.
use_external_states
,
average_pooling_type
=
backbone_cfg
.
average_pooling_type
)
official/projects/movinet/modeling/movinet_layers.py
View file @
57557684
...
...
@@ -802,12 +802,14 @@ class StreamSqueezeExcitation(tf.keras.layers.Layer):
states
=
dict
(
states
)
if
states
is
not
None
else
{}
if
self
.
_se_type
==
'3d'
:
x
,
states
=
self
.
_spatiotemporal_pool
(
inputs
,
states
=
states
)
x
,
states
=
self
.
_spatiotemporal_pool
(
inputs
,
states
=
states
,
output_states
=
True
)
elif
self
.
_se_type
==
'2d'
:
x
=
self
.
_spatial_pool
(
inputs
)
elif
self
.
_se_type
==
'2plus3d'
:
x_space
=
self
.
_spatial_pool
(
inputs
)
x
,
states
=
self
.
_spatiotemporal_pool
(
x_space
,
states
=
states
)
x
,
states
=
self
.
_spatiotemporal_pool
(
x_space
,
states
=
states
,
output_states
=
True
)
if
not
self
.
_causal
:
x
=
tf
.
tile
(
x
,
[
1
,
tf
.
shape
(
inputs
)[
1
],
1
,
1
,
1
])
...
...
@@ -1362,6 +1364,7 @@ class Head(tf.keras.layers.Layer):
tf
.
keras
.
layers
.
BatchNormalization
,
batch_norm_momentum
:
float
=
0.99
,
batch_norm_epsilon
:
float
=
1e-3
,
average_pooling_type
:
str
=
'3d'
,
state_prefix
:
Optional
[
str
]
=
None
,
# pytype: disable=annotation-type-mismatch # typed-keras
**
kwargs
):
"""Implementation for video model head.
...
...
@@ -1378,6 +1381,8 @@ class Head(tf.keras.layers.Layer):
batch_norm_layer: class to use for batch norm.
batch_norm_momentum: momentum of the batch norm operation.
batch_norm_epsilon: epsilon of the batch norm operation.
average_pooling_type: The average pooling type. Currently supporting
['3d', '2d', 'none'].
state_prefix: a prefix string to identify states.
**kwargs: keyword arguments to be passed to this layer.
"""
...
...
@@ -1404,8 +1409,16 @@ class Head(tf.keras.layers.Layer):
batch_norm_momentum
=
self
.
_batch_norm_momentum
,
batch_norm_epsilon
=
self
.
_batch_norm_epsilon
,
name
=
'project'
)
if
average_pooling_type
.
lower
()
==
'3d'
:
self
.
_pool
=
nn_layers
.
GlobalAveragePool3D
(
keepdims
=
True
,
causal
=
False
,
state_prefix
=
state_prefix
)
elif
average_pooling_type
.
lower
()
==
'2d'
:
self
.
_pool
=
nn_layers
.
SpatialAveragePool3D
(
keepdims
=
True
)
elif
average_pooling_type
==
'none'
:
self
.
_pool
=
None
else
:
raise
ValueError
(
'%s average_pooling_type is not supported.'
%
average_pooling_type
)
def
get_config
(
self
):
"""Returns a dictionary containing the config used for initialization."""
...
...
@@ -1439,7 +1452,11 @@ class Head(tf.keras.layers.Layer):
"""
states
=
dict
(
states
)
if
states
is
not
None
else
{}
x
=
self
.
_project
(
inputs
)
return
self
.
_pool
(
x
,
states
=
states
)
if
self
.
_pool
is
not
None
:
outputs
=
self
.
_pool
(
x
,
states
=
states
,
output_states
=
True
)
else
:
outputs
=
x
return
outputs
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Vision'
)
...
...
official/projects/qat/vision/configs/common.py
View file @
57557684
...
...
@@ -30,8 +30,11 @@ class Quantization(hyperparams.Config):
change_num_bits: A `bool` indicates whether to manually allocate num_bits.
num_bits_weight: An `int` number of bits for weight. Default to 8.
num_bits_activation: An `int` number of bits for activation. Default to 8.
quantize_detection_head: A `bool` indicates whether to quantize detection
head. It only works for detection model.
"""
pretrained_original_checkpoint
:
Optional
[
str
]
=
None
change_num_bits
:
bool
=
False
num_bits_weight
:
int
=
8
num_bits_activation
:
int
=
8
quantize_detection_head
:
bool
=
False
official/projects/qat/vision/configs/experiments/retinanet/coco_spinenet49_mobile_qat_tpu.yaml
0 → 100644
View file @
57557684
# --experiment_type=retinanet_spinenet_mobile_coco_qat
# COCO mAP: 24.7
# QAT only supports float32 tpu due to fake-quant op.
runtime
:
distribution_strategy
:
'
tpu'
mixed_precision_dtype
:
'
float32'
task
:
losses
:
l2_weight_decay
:
3.0e-05
model
:
anchor
:
anchor_size
:
3
aspect_ratios
:
[
0.5
,
1.0
,
2.0
]
num_scales
:
3
backbone
:
spinenet_mobile
:
stochastic_depth_drop_rate
:
0.2
model_id
:
'
49'
se_ratio
:
0.2
use_keras_upsampling_2d
:
true
type
:
'
spinenet_mobile'
decoder
:
type
:
'
identity'
head
:
num_convs
:
4
num_filters
:
48
use_separable_conv
:
true
input_size
:
[
384
,
384
,
3
]
max_level
:
7
min_level
:
3
norm_activation
:
activation
:
'
swish'
norm_epsilon
:
0.001
norm_momentum
:
0.99
use_sync_bn
:
true
train_data
:
dtype
:
'
float32'
global_batch_size
:
128
is_training
:
true
parser
:
aug_rand_hflip
:
true
aug_scale_max
:
2.0
aug_scale_min
:
0.5
validation_data
:
dtype
:
'
float32'
global_batch_size
:
16
is_training
:
false
quantization
:
pretrained_original_checkpoint
:
'
gs://**/coco_spinenet49_mobile_tpu_33884721/ckpt-277200'
trainer
:
checkpoint_interval
:
924
optimizer_config
:
learning_rate
:
stepwise
:
boundaries
:
[
531300
,
545160
]
values
:
[
0.0016
,
0.00016
,
0.000016
]
type
:
'
stepwise'
warmup
:
linear
:
warmup_learning_rate
:
0.0000335
warmup_steps
:
4000
steps_per_loop
:
924
train_steps
:
554400
validation_interval
:
924
validation_steps
:
1250
summary_interval
:
924
official/projects/qat/vision/configs/experiments/retinanet/coco_spinenet49_mobile_qat_tpu_e2e.yaml
0 → 100644
View file @
57557684
# --experiment_type=retinanet_spinenet_mobile_coco_qat
# COCO mAP: 22.0
# QAT only supports float32 tpu due to fake-quant op.
runtime
:
distribution_strategy
:
'
tpu'
mixed_precision_dtype
:
'
float32'
task
:
losses
:
l2_weight_decay
:
3.0e-05
model
:
anchor
:
anchor_size
:
3
aspect_ratios
:
[
0.5
,
1.0
,
2.0
]
num_scales
:
3
backbone
:
spinenet_mobile
:
stochastic_depth_drop_rate
:
0.2
model_id
:
'
49'
se_ratio
:
0.2
use_keras_upsampling_2d
:
true
type
:
'
spinenet_mobile'
decoder
:
type
:
'
identity'
head
:
num_convs
:
4
num_filters
:
48
use_separable_conv
:
true
input_size
:
[
384
,
384
,
3
]
max_level
:
7
min_level
:
3
norm_activation
:
activation
:
'
swish'
norm_epsilon
:
0.001
norm_momentum
:
0.99
use_sync_bn
:
true
train_data
:
dtype
:
'
float32'
global_batch_size
:
256
is_training
:
true
parser
:
aug_rand_hflip
:
true
aug_scale_max
:
2.0
aug_scale_min
:
0.5
validation_data
:
dtype
:
'
float32'
global_batch_size
:
16
is_training
:
false
quantization
:
pretrained_original_checkpoint
:
'
gs://**/coco_spinenet49_mobile_tpu_33884721/ckpt-277200'
quantize_detection_head
:
true
trainer
:
checkpoint_interval
:
462
optimizer_config
:
learning_rate
:
stepwise
:
boundaries
:
[
263340
,
272580
]
values
:
[
0.032
,
0.0032
,
0.00032
]
type
:
'
stepwise'
warmup
:
linear
:
warmup_learning_rate
:
0.00067
warmup_steps
:
2000
steps_per_loop
:
462
train_steps
:
277200
validation_interval
:
462
validation_steps
:
625
summary_interval
:
924
official/projects/qat/vision/modeling/factory.py
View file @
57557684
...
...
@@ -20,12 +20,14 @@ import tensorflow as tf
import
tensorflow_model_optimization
as
tfmot
from
official.projects.qat.vision.configs
import
common
from
official.projects.qat.vision.modeling
import
segmentation_model
as
qat_segmentation_model
from
official.projects.qat.vision.modeling.heads
import
dense_prediction_heads
as
dense_prediction_heads_qat
from
official.projects.qat.vision.n_bit
import
schemes
as
n_bit_schemes
from
official.projects.qat.vision.quantization
import
schemes
from
official.vision
import
configs
from
official.vision.modeling
import
classification_model
from
official.vision.modeling
import
retinanet_model
from
official.vision.modeling.decoders
import
aspp
from
official.vision.modeling.heads
import
dense_prediction_heads
from
official.vision.modeling.heads
import
segmentation_heads
from
official.vision.modeling.layers
import
nn_layers
...
...
@@ -148,10 +150,17 @@ def build_qat_retinanet(
optimized_backbone
=
tfmot
.
quantization
.
keras
.
quantize_apply
(
annotated_backbone
,
scheme
=
schemes
.
Default8BitQuantizeScheme
())
head
=
model
.
head
if
quantization
.
quantize_detection_head
:
if
not
isinstance
(
head
,
dense_prediction_heads
.
RetinaNetHead
):
raise
ValueError
(
'Currently only supports RetinaNetHead.'
)
head
=
(
dense_prediction_heads_qat
.
RetinaNetHeadQuantized
.
from_config
(
head
.
get_config
()))
optimized_model
=
retinanet_model
.
RetinaNetModel
(
optimized_backbone
,
model
.
decoder
,
model
.
head
,
head
,
model
.
detection_generator
,
min_level
=
model_config
.
min_level
,
max_level
=
model_config
.
max_level
,
...
...
official/projects/qat/vision/modeling/heads/__init__.py
0 → 100644
View file @
57557684
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Heads package definition."""
from
official.projects.qat.vision.modeling.heads.dense_prediction_heads
import
RetinaNetHeadQuantized
official/projects/qat/vision/modeling/heads/dense_prediction_heads.py
0 → 100644
View file @
57557684
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains definitions of dense prediction heads."""
from
__future__
import
annotations
import
copy
from
typing
import
Any
,
Dict
,
List
,
Mapping
,
Optional
,
Union
,
Type
# Import libraries
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow_model_optimization
as
tfmot
from
official.modeling
import
tf_utils
from
official.projects.qat.vision.quantization
import
configs
from
official.projects.qat.vision.quantization
import
helper
class
SeparableConv2DQuantized
(
tf
.
keras
.
layers
.
Layer
):
"""Quantized SeperableConv2D."""
def
__init__
(
self
,
name
:
Optional
[
str
]
=
None
,
last_quantize
:
bool
=
False
,
**
conv_kwargs
):
"""Initializes a SeparableConv2DQuantized.
Args:
name: The name of the layer.
last_quantize: A `bool` indicates whether add quantization for the output.
**conv_kwargs: A keyword arguments to be used for conv and dwconv.
"""
super
().
__init__
(
name
=
name
)
self
.
_conv_kwargs
=
copy
.
deepcopy
(
conv_kwargs
)
self
.
_name
=
name
self
.
_last_quantize
=
last_quantize
def
build
(
self
,
input_shape
:
Union
[
tf
.
TensorShape
,
List
[
tf
.
TensorShape
]]):
"""Creates the child layers of the layer."""
depthwise_conv2d_quantized
=
helper
.
quantize_wrapped_layer
(
tf
.
keras
.
layers
.
DepthwiseConv2D
,
configs
.
Default8BitConvQuantizeConfig
(
[
'depthwise_kernel'
],
[
'activation'
],
True
))
conv2d_quantized
=
helper
.
quantize_wrapped_layer
(
tf
.
keras
.
layers
.
Conv2D
,
configs
.
Default8BitConvQuantizeConfig
(
[
'kernel'
],
[],
self
.
_last_quantize
))
dwconv_kwargs
=
self
.
_conv_kwargs
.
copy
()
# Depthwise conv input filters is always equal to output filters.
# This filters argument only needed for the point-wise conv2d op.
del
dwconv_kwargs
[
'filters'
]
self
.
dw_conv
=
depthwise_conv2d_quantized
(
name
=
'dw'
,
**
dwconv_kwargs
)
conv_kwargs
=
self
.
_conv_kwargs
.
copy
()
conv_kwargs
.
update
({
'kernel_size'
:
(
1
,
1
),
'strides'
:
(
1
,
1
),
'padding'
:
'valid'
,
'groups'
:
1
,
})
self
.
conv
=
conv2d_quantized
(
name
=
'pw'
,
**
conv_kwargs
)
def
call
(
self
,
inputs
:
tf
.
Tensor
)
->
tf
.
Tensor
:
"""Call the separable conv layer."""
x
=
self
.
dw_conv
(
inputs
)
outputs
=
self
.
conv
(
x
)
return
outputs
def
get_config
(
self
)
->
Dict
[
str
,
Any
]:
"""Returns the config of the layer."""
config
=
self
.
_conv_kwargs
.
copy
()
config
.
update
({
'name'
:
self
.
_name
,
'last_quantize'
:
self
.
_last_quantize
,
})
return
config
@
classmethod
def
from_config
(
cls
:
Type
[
SeparableConv2DQuantized
],
config
:
Dict
[
str
,
Any
])
->
SeparableConv2DQuantized
:
"""Creates a layer from its config."""
return
cls
(
**
config
)
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Vision'
)
class
RetinaNetHeadQuantized
(
tf
.
keras
.
layers
.
Layer
):
"""Creates a RetinaNet quantized head."""
def
__init__
(
self
,
min_level
:
int
,
max_level
:
int
,
num_classes
:
int
,
num_anchors_per_location
:
int
,
num_convs
:
int
=
4
,
num_filters
:
int
=
256
,
attribute_heads
:
Optional
[
List
[
Dict
[
str
,
Any
]]]
=
None
,
use_separable_conv
:
bool
=
False
,
activation
:
str
=
'relu'
,
use_sync_bn
:
bool
=
False
,
norm_momentum
:
float
=
0.99
,
norm_epsilon
:
float
=
0.001
,
kernel_regularizer
:
Optional
[
tf
.
keras
.
regularizers
.
Regularizer
]
=
None
,
bias_regularizer
:
Optional
[
tf
.
keras
.
regularizers
.
Regularizer
]
=
None
,
num_params_per_anchor
:
int
=
4
,
**
kwargs
):
"""Initializes a RetinaNet quantized head.
Args:
min_level: An `int` number of minimum feature level.
max_level: An `int` number of maximum feature level.
num_classes: An `int` number of classes to predict.
num_anchors_per_location: An `int` number of number of anchors per pixel
location.
num_convs: An `int` number that represents the number of the intermediate
conv layers before the prediction.
num_filters: An `int` number that represents the number of filters of the
intermediate conv layers.
attribute_heads: If not None, a list that contains a dict for each
additional attribute head. Each dict consists of 3 key-value pairs:
`name`, `type` ('regression' or 'classification'), and `size` (number
of predicted values for each instance).
use_separable_conv: A `bool` that indicates whether the separable
convolution layers is used.
activation: A `str` that indicates which activation is used, e.g. 'relu',
'swish', etc.
use_sync_bn: A `bool` that indicates whether to use synchronized batch
normalization across different replicas.
norm_momentum: A `float` of normalization momentum for the moving average.
norm_epsilon: A `float` added to variance to avoid dividing by zero.
kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
Conv2D. Default is None.
bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2D.
num_params_per_anchor: Number of parameters required to specify an anchor
box. For example, `num_params_per_anchor` would be 4 for axis-aligned
anchor boxes specified by their y-centers, x-centers, heights, and
widths.
**kwargs: Additional keyword arguments to be passed.
"""
super
().
__init__
(
**
kwargs
)
self
.
_config_dict
=
{
'min_level'
:
min_level
,
'max_level'
:
max_level
,
'num_classes'
:
num_classes
,
'num_anchors_per_location'
:
num_anchors_per_location
,
'num_convs'
:
num_convs
,
'num_filters'
:
num_filters
,
'attribute_heads'
:
attribute_heads
,
'use_separable_conv'
:
use_separable_conv
,
'activation'
:
activation
,
'use_sync_bn'
:
use_sync_bn
,
'norm_momentum'
:
norm_momentum
,
'norm_epsilon'
:
norm_epsilon
,
'kernel_regularizer'
:
kernel_regularizer
,
'bias_regularizer'
:
bias_regularizer
,
'num_params_per_anchor'
:
num_params_per_anchor
,
}
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_last'
:
self
.
_bn_axis
=
-
1
else
:
self
.
_bn_axis
=
1
self
.
_activation
=
tfmot
.
quantization
.
keras
.
QuantizeWrapperV2
(
tf_utils
.
get_activation
(
activation
,
use_keras_layer
=
True
),
configs
.
Default8BitActivationQuantizeConfig
())
def
build
(
self
,
input_shape
:
Union
[
tf
.
TensorShape
,
List
[
tf
.
TensorShape
]]):
"""Creates the variables of the head."""
if
self
.
_config_dict
[
'use_separable_conv'
]:
conv_op
=
SeparableConv2DQuantized
else
:
conv_op
=
helper
.
quantize_wrapped_layer
(
tf
.
keras
.
layers
.
Conv2D
,
configs
.
Default8BitConvQuantizeConfig
(
[
'kernel'
],
[
'activation'
],
False
))
conv_kwargs
=
{
'filters'
:
self
.
_config_dict
[
'num_filters'
],
'kernel_size'
:
3
,
'padding'
:
'same'
,
'bias_initializer'
:
tf
.
zeros_initializer
(),
'bias_regularizer'
:
self
.
_config_dict
[
'bias_regularizer'
],
}
if
not
self
.
_config_dict
[
'use_separable_conv'
]:
conv_kwargs
.
update
({
'kernel_initializer'
:
tf
.
keras
.
initializers
.
RandomNormal
(
stddev
=
0.01
),
'kernel_regularizer'
:
self
.
_config_dict
[
'kernel_regularizer'
],
})
base_bn_op
=
(
tf
.
keras
.
layers
.
experimental
.
SyncBatchNormalization
if
self
.
_config_dict
[
'use_sync_bn'
]
else
tf
.
keras
.
layers
.
BatchNormalization
)
bn_op
=
helper
.
norm_by_activation
(
self
.
_config_dict
[
'activation'
],
helper
.
quantize_wrapped_layer
(
base_bn_op
,
configs
.
Default8BitOutputQuantizeConfig
()),
helper
.
quantize_wrapped_layer
(
base_bn_op
,
configs
.
NoOpQuantizeConfig
()))
bn_kwargs
=
{
'axis'
:
self
.
_bn_axis
,
'momentum'
:
self
.
_config_dict
[
'norm_momentum'
],
'epsilon'
:
self
.
_config_dict
[
'norm_epsilon'
],
}
# Class net.
self
.
_cls_convs
=
[]
self
.
_cls_norms
=
[]
for
level
in
range
(
self
.
_config_dict
[
'min_level'
],
self
.
_config_dict
[
'max_level'
]
+
1
):
this_level_cls_norms
=
[]
for
i
in
range
(
self
.
_config_dict
[
'num_convs'
]):
if
level
==
self
.
_config_dict
[
'min_level'
]:
cls_conv_name
=
'classnet-conv_{}'
.
format
(
i
)
self
.
_cls_convs
.
append
(
conv_op
(
name
=
cls_conv_name
,
**
conv_kwargs
))
cls_norm_name
=
'classnet-conv-norm_{}_{}'
.
format
(
level
,
i
)
this_level_cls_norms
.
append
(
bn_op
(
name
=
cls_norm_name
,
**
bn_kwargs
))
self
.
_cls_norms
.
append
(
this_level_cls_norms
)
classifier_kwargs
=
{
'filters'
:
(
self
.
_config_dict
[
'num_classes'
]
*
self
.
_config_dict
[
'num_anchors_per_location'
]),
'kernel_size'
:
3
,
'padding'
:
'same'
,
'bias_initializer'
:
tf
.
constant_initializer
(
-
np
.
log
((
1
-
0.01
)
/
0.01
)),
'bias_regularizer'
:
self
.
_config_dict
[
'bias_regularizer'
],
}
if
not
self
.
_config_dict
[
'use_separable_conv'
]:
classifier_kwargs
.
update
({
'kernel_initializer'
:
tf
.
keras
.
initializers
.
RandomNormal
(
stddev
=
1e-5
),
'kernel_regularizer'
:
self
.
_config_dict
[
'kernel_regularizer'
],
})
self
.
_classifier
=
conv_op
(
name
=
'scores'
,
last_quantize
=
True
,
**
classifier_kwargs
)
# Box net.
self
.
_box_convs
=
[]
self
.
_box_norms
=
[]
for
level
in
range
(
self
.
_config_dict
[
'min_level'
],
self
.
_config_dict
[
'max_level'
]
+
1
):
this_level_box_norms
=
[]
for
i
in
range
(
self
.
_config_dict
[
'num_convs'
]):
if
level
==
self
.
_config_dict
[
'min_level'
]:
box_conv_name
=
'boxnet-conv_{}'
.
format
(
i
)
self
.
_box_convs
.
append
(
conv_op
(
name
=
box_conv_name
,
**
conv_kwargs
))
box_norm_name
=
'boxnet-conv-norm_{}_{}'
.
format
(
level
,
i
)
this_level_box_norms
.
append
(
bn_op
(
name
=
box_norm_name
,
**
bn_kwargs
))
self
.
_box_norms
.
append
(
this_level_box_norms
)
box_regressor_kwargs
=
{
'filters'
:
(
self
.
_config_dict
[
'num_params_per_anchor'
]
*
self
.
_config_dict
[
'num_anchors_per_location'
]),
'kernel_size'
:
3
,
'padding'
:
'same'
,
'bias_initializer'
:
tf
.
zeros_initializer
(),
'bias_regularizer'
:
self
.
_config_dict
[
'bias_regularizer'
],
}
if
not
self
.
_config_dict
[
'use_separable_conv'
]:
box_regressor_kwargs
.
update
({
'kernel_initializer'
:
tf
.
keras
.
initializers
.
RandomNormal
(
stddev
=
1e-5
),
'kernel_regularizer'
:
self
.
_config_dict
[
'kernel_regularizer'
],
})
self
.
_box_regressor
=
conv_op
(
name
=
'boxes'
,
last_quantize
=
True
,
**
box_regressor_kwargs
)
# Attribute learning nets.
if
self
.
_config_dict
[
'attribute_heads'
]:
self
.
_att_predictors
=
{}
self
.
_att_convs
=
{}
self
.
_att_norms
=
{}
for
att_config
in
self
.
_config_dict
[
'attribute_heads'
]:
att_name
=
att_config
[
'name'
]
att_type
=
att_config
[
'type'
]
att_size
=
att_config
[
'size'
]
att_convs_i
=
[]
att_norms_i
=
[]
# Build conv and norm layers.
for
level
in
range
(
self
.
_config_dict
[
'min_level'
],
self
.
_config_dict
[
'max_level'
]
+
1
):
this_level_att_norms
=
[]
for
i
in
range
(
self
.
_config_dict
[
'num_convs'
]):
if
level
==
self
.
_config_dict
[
'min_level'
]:
att_conv_name
=
'{}-conv_{}'
.
format
(
att_name
,
i
)
att_convs_i
.
append
(
conv_op
(
name
=
att_conv_name
,
**
conv_kwargs
))
att_norm_name
=
'{}-conv-norm_{}_{}'
.
format
(
att_name
,
level
,
i
)
this_level_att_norms
.
append
(
bn_op
(
name
=
att_norm_name
,
**
bn_kwargs
))
att_norms_i
.
append
(
this_level_att_norms
)
self
.
_att_convs
[
att_name
]
=
att_convs_i
self
.
_att_norms
[
att_name
]
=
att_norms_i
# Build the final prediction layer.
att_predictor_kwargs
=
{
'filters'
:
(
att_size
*
self
.
_config_dict
[
'num_anchors_per_location'
]),
'kernel_size'
:
3
,
'padding'
:
'same'
,
'bias_initializer'
:
tf
.
zeros_initializer
(),
'bias_regularizer'
:
self
.
_config_dict
[
'bias_regularizer'
],
}
if
att_type
==
'regression'
:
att_predictor_kwargs
.
update
(
{
'bias_initializer'
:
tf
.
zeros_initializer
()})
elif
att_type
==
'classification'
:
att_predictor_kwargs
.
update
({
'bias_initializer'
:
tf
.
constant_initializer
(
-
np
.
log
((
1
-
0.01
)
/
0.01
))
})
else
:
raise
ValueError
(
'Attribute head type {} not supported.'
.
format
(
att_type
))
if
not
self
.
_config_dict
[
'use_separable_conv'
]:
att_predictor_kwargs
.
update
({
'kernel_initializer'
:
tf
.
keras
.
initializers
.
RandomNormal
(
stddev
=
1e-5
),
'kernel_regularizer'
:
self
.
_config_dict
[
'kernel_regularizer'
],
})
self
.
_att_predictors
[
att_name
]
=
conv_op
(
name
=
'{}_attributes'
.
format
(
att_name
),
**
att_predictor_kwargs
)
super
().
build
(
input_shape
)
def
call
(
self
,
features
:
Mapping
[
str
,
tf
.
Tensor
]):
"""Forward pass of the RetinaNet quantized head.
Args:
features: A `dict` of `tf.Tensor` where
- key: A `str` of the level of the multilevel features.
- values: A `tf.Tensor`, the feature map tensors, whose shape is
[batch, height_l, width_l, channels].
Returns:
scores: A `dict` of `tf.Tensor` which includes scores of the predictions.
- key: A `str` of the level of the multilevel predictions.
- values: A `tf.Tensor` of the box scores predicted from a particular
feature level, whose shape is
[batch, height_l, width_l, num_classes * num_anchors_per_location].
boxes: A `dict` of `tf.Tensor` which includes coordinates of the
predictions.
- key: A `str` of the level of the multilevel predictions.
- values: A `tf.Tensor` of the box scores predicted from a particular
feature level, whose shape is
[batch, height_l, width_l,
num_params_per_anchor * num_anchors_per_location].
attributes: a dict of (attribute_name, attribute_prediction). Each
`attribute_prediction` is a dict of:
- key: `str`, the level of the multilevel predictions.
- values: `Tensor`, the box scores predicted from a particular feature
level, whose shape is
[batch, height_l, width_l,
attribute_size * num_anchors_per_location].
Can be an empty dictionary if no attribute learning is required.
"""
scores
=
{}
boxes
=
{}
if
self
.
_config_dict
[
'attribute_heads'
]:
attributes
=
{
att_config
[
'name'
]:
{}
for
att_config
in
self
.
_config_dict
[
'attribute_heads'
]
}
else
:
attributes
=
{}
for
i
,
level
in
enumerate
(
range
(
self
.
_config_dict
[
'min_level'
],
self
.
_config_dict
[
'max_level'
]
+
1
)):
this_level_features
=
features
[
str
(
level
)]
# class net.
x
=
this_level_features
for
conv
,
norm
in
zip
(
self
.
_cls_convs
,
self
.
_cls_norms
[
i
]):
x
=
conv
(
x
)
x
=
norm
(
x
)
x
=
self
.
_activation
(
x
)
scores
[
str
(
level
)]
=
self
.
_classifier
(
x
)
# box net.
x
=
this_level_features
for
conv
,
norm
in
zip
(
self
.
_box_convs
,
self
.
_box_norms
[
i
]):
x
=
conv
(
x
)
x
=
norm
(
x
)
x
=
self
.
_activation
(
x
)
boxes
[
str
(
level
)]
=
self
.
_box_regressor
(
x
)
# attribute nets.
if
self
.
_config_dict
[
'attribute_heads'
]:
for
att_config
in
self
.
_config_dict
[
'attribute_heads'
]:
att_name
=
att_config
[
'name'
]
x
=
this_level_features
for
conv
,
norm
in
zip
(
self
.
_att_convs
[
att_name
],
self
.
_att_norms
[
att_name
][
i
]):
x
=
conv
(
x
)
x
=
norm
(
x
)
x
=
self
.
_activation
(
x
)
attributes
[
att_name
][
str
(
level
)]
=
self
.
_att_predictors
[
att_name
](
x
)
return
scores
,
boxes
,
attributes
def
get_config
(
self
):
return
self
.
_config_dict
@
classmethod
def
from_config
(
cls
,
config
):
return
cls
(
**
config
)
official/projects/qat/vision/modeling/heads/dense_prediction_heads_test.py
0 → 100644
View file @
57557684
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for dense_prediction_heads.py."""
# Import libraries
from
absl.testing
import
parameterized
import
numpy
as
np
import
tensorflow
as
tf
from
official.projects.qat.vision.modeling.heads
import
dense_prediction_heads
class
RetinaNetHeadQuantizedTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
@
parameterized
.
parameters
(
(
False
,
False
,
False
),
(
False
,
True
,
False
),
(
True
,
False
,
True
),
(
True
,
True
,
True
),
)
def
test_forward
(
self
,
use_separable_conv
,
use_sync_bn
,
has_att_heads
):
if
has_att_heads
:
attribute_heads
=
[
dict
(
name
=
'depth'
,
type
=
'regression'
,
size
=
1
)]
else
:
attribute_heads
=
None
retinanet_head
=
dense_prediction_heads
.
RetinaNetHeadQuantized
(
min_level
=
3
,
max_level
=
4
,
num_classes
=
3
,
num_anchors_per_location
=
3
,
num_convs
=
2
,
num_filters
=
256
,
attribute_heads
=
attribute_heads
,
use_separable_conv
=
use_separable_conv
,
activation
=
'relu'
,
use_sync_bn
=
use_sync_bn
,
norm_momentum
=
0.99
,
norm_epsilon
=
0.001
,
kernel_regularizer
=
None
,
bias_regularizer
=
None
,
)
features
=
{
'3'
:
np
.
random
.
rand
(
2
,
128
,
128
,
16
),
'4'
:
np
.
random
.
rand
(
2
,
64
,
64
,
16
),
}
scores
,
boxes
,
attributes
=
retinanet_head
(
features
)
self
.
assertAllEqual
(
scores
[
'3'
].
numpy
().
shape
,
[
2
,
128
,
128
,
9
])
self
.
assertAllEqual
(
scores
[
'4'
].
numpy
().
shape
,
[
2
,
64
,
64
,
9
])
self
.
assertAllEqual
(
boxes
[
'3'
].
numpy
().
shape
,
[
2
,
128
,
128
,
12
])
self
.
assertAllEqual
(
boxes
[
'4'
].
numpy
().
shape
,
[
2
,
64
,
64
,
12
])
if
has_att_heads
:
for
att
in
attributes
.
values
():
self
.
assertAllEqual
(
att
[
'3'
].
numpy
().
shape
,
[
2
,
128
,
128
,
3
])
self
.
assertAllEqual
(
att
[
'4'
].
numpy
().
shape
,
[
2
,
64
,
64
,
3
])
def
test_serialize_deserialize
(
self
):
retinanet_head
=
dense_prediction_heads
.
RetinaNetHeadQuantized
(
min_level
=
3
,
max_level
=
7
,
num_classes
=
3
,
num_anchors_per_location
=
9
,
num_convs
=
2
,
num_filters
=
16
,
attribute_heads
=
None
,
use_separable_conv
=
False
,
activation
=
'relu'
,
use_sync_bn
=
False
,
norm_momentum
=
0.99
,
norm_epsilon
=
0.001
,
kernel_regularizer
=
None
,
bias_regularizer
=
None
,
)
config
=
retinanet_head
.
get_config
()
new_retinanet_head
=
(
dense_prediction_heads
.
RetinaNetHead
.
from_config
(
config
))
self
.
assertAllEqual
(
retinanet_head
.
get_config
(),
new_retinanet_head
.
get_config
())
official/projects/video_ssl/dataloaders/video_ssl_input.py
View file @
57557684
...
...
@@ -128,6 +128,9 @@ def _process_image(image: tf.Tensor,
# Self-supervised pre-training augmentations.
if
is_training
and
is_ssl
:
if
zero_centering_image
:
image_1
=
0.5
*
(
image_1
+
1.0
)
image_2
=
0.5
*
(
image_2
+
1.0
)
# Temporally consistent color jittering.
image_1
=
video_ssl_preprocess_ops
.
random_color_jitter_3d
(
image_1
)
image_2
=
video_ssl_preprocess_ops
.
random_color_jitter_3d
(
image_2
)
...
...
@@ -139,6 +142,8 @@ def _process_image(image: tf.Tensor,
image_2
=
video_ssl_preprocess_ops
.
random_solarization
(
image_2
)
image
=
tf
.
concat
([
image_1
,
image_2
],
axis
=
0
)
image
=
tf
.
clip_by_value
(
image
,
0.
,
1.
)
if
zero_centering_image
:
image
=
2
*
(
image
-
0.5
)
return
image
...
...
@@ -233,7 +238,8 @@ class Parser(video_input.Parser):
stride
=
self
.
_stride
,
num_test_clips
=
self
.
_num_test_clips
,
min_resize
=
self
.
_min_resize
,
crop_size
=
self
.
_crop_size
)
crop_size
=
self
.
_crop_size
,
zero_centering_image
=
self
.
_zero_centering_image
)
image
=
tf
.
cast
(
image
,
dtype
=
self
.
_dtype
)
features
=
{
'image'
:
image
}
...
...
@@ -255,7 +261,8 @@ class Parser(video_input.Parser):
num_test_clips
=
self
.
_num_test_clips
,
min_resize
=
self
.
_min_resize
,
crop_size
=
self
.
_crop_size
,
num_crops
=
self
.
_num_crops
)
num_crops
=
self
.
_num_crops
,
zero_centering_image
=
self
.
_zero_centering_image
)
image
=
tf
.
cast
(
image
,
dtype
=
self
.
_dtype
)
features
=
{
'image'
:
image
}
...
...
official/utils/docs/build_tfm_api_docs.py
View file @
57557684
...
...
@@ -29,7 +29,10 @@ from absl import logging
import
tensorflow
as
tf
from
tensorflow_docs.api_generator
import
doc_controls
from
tensorflow_docs.api_generator
import
generate_lib
from
tensorflow_docs.api_generator
import
parser
from
tensorflow_docs.api_generator
import
public_api
from
tensorflow_docs.api_generator.pretty_docs
import
base_page
from
tensorflow_docs.api_generator.pretty_docs
import
function_page
import
tensorflow_models
as
tfm
...
...
@@ -52,6 +55,44 @@ PROJECT_SHORT_NAME = 'tfm'
PROJECT_FULL_NAME
=
'TensorFlow Modeling Library'
class ExpFactoryInfo(function_page.FunctionPageInfo):
  """Customize the page for the experiment factory.

  Extends the standard function page builder so the generated docs for the
  experiment factory include a markdown table of every registered
  experiment name and a one-line description.
  """

  def collect_docs(self):
    """Collects the standard docs, then appends the experiment-name table."""
    super().collect_docs()
    self.doc.docstring_parts.append(self.make_factory_options_table())

  def make_factory_options_table(self):
    """Builds an HTML table of registered experiment names.

    Iterates the experiment registry (sorted by name) and, for each entry,
    emits a row with a link to the factory function (cross-reference link
    when the function is in the API tree, source link otherwise) and the
    first line of its docstring.

    Returns:
      A single newline-joined string containing the HTML table.
    """
    lines = [
        '',
        'Allowed values for `exp_name`:',
        '',
        # The indent is important here, it keeps the site's markdown parser
        # from switching to HTML mode.
        ' <table>\n',
        '<th><code>exp_name</code></th><th>Description</th>',
    ]
    reference_resolver = self.parser_config.reference_resolver
    api_tree = self.parser_config.api_tree
    # NOTE(review): reads the private registry directly; hence the
    # protected-access suppression below.
    for name, fn in sorted(tfm.core.exp_factory._REGISTERED_CONFIGS.items()):
      # pylint: disable=protected-access
      fn_api_node = api_tree.node_for_object(fn)
      if fn_api_node is None:
        # Not a documented API symbol: fall back to a small source-code link.
        location = parser.get_defined_in(self.py_object, self.parser_config)
        link = base_page.small_source_link(location, name)
      else:
        link = reference_resolver.python_link(name, fn_api_node.full_name)
      doc = fn.__doc__
      if doc:
        # Only the first line of the docstring is used as the description.
        doc = doc.splitlines()[0]
      else:
        doc = ''
      lines.append(f'<tr><td>{link}</td><td>{doc}</td></tr>')
    lines.append('</table>')
    return '\n'.join(lines)
def
hide_module_model_and_layer_methods
():
"""Hide methods and properties defined in the base classes of Keras layers.
...
...
@@ -103,6 +144,9 @@ def gen_api_docs(code_url_prefix, site_path, output_dir, project_short_name,
del
tfm
.
nlp
.
layers
.
MultiHeadAttention
del
tfm
.
nlp
.
layers
.
EinsumDense
doc_controls
.
set_custom_page_builder_cls
(
tfm
.
core
.
exp_factory
.
get_exp_config
,
ExpFactoryInfo
)
url_parts
=
code_url_prefix
.
strip
(
'/'
).
split
(
'/'
)
url_parts
=
url_parts
[:
url_parts
.
index
(
'tensorflow_models'
)]
url_parts
.
append
(
'official'
)
...
...
official/vision/beta/__init__.py
0 → 100644
View file @
57557684
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
official/vision/beta/projects/yolo/modeling/__init__.py
0 → 100644
View file @
57557684
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
official/vision/beta/projects/yolo/modeling/backbones/__init__.py
0 → 100644
View file @
57557684
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
official/vision/beta/projects/yolo/modeling/layers/__init__.py
0 → 100644
View file @
57557684
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
official/vision/beta/projects/yolo/serving/__init__.py
0 → 100644
View file @
57557684
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
official/vision/configs/semantic_segmentation.py
View file @
57557684
...
...
@@ -116,6 +116,12 @@ class Evaluation(hyperparams.Config):
report_train_mean_iou
:
bool
=
True
# Turning this off can speed up training.
@dataclasses.dataclass
class ExportConfig(hyperparams.Config):
  """Export-time options for semantic segmentation models."""

  # Whether to rescale the predicted mask to the original image size.
  rescale_output: bool = False
@
dataclasses
.
dataclass
class
SemanticSegmentationTask
(
cfg
.
TaskConfig
):
"""The model config."""
...
...
@@ -131,6 +137,7 @@ class SemanticSegmentationTask(cfg.TaskConfig):
init_checkpoint
:
Optional
[
str
]
=
None
init_checkpoint_modules
:
Union
[
str
,
List
[
str
]]
=
'all'
# all, backbone, and/or decoder
export_config
:
ExportConfig
=
ExportConfig
()
@
exp_factory
.
register_config_factory
(
'semantic_segmentation'
)
...
...
official/vision/configs/video_classification.py
View file @
57557684
...
...
@@ -49,6 +49,7 @@ class DataConfig(cfg.DataConfig):
cycle_length
:
int
=
10
drop_remainder
:
bool
=
True
min_image_size
:
int
=
256
zero_centering_image
:
bool
=
False
is_multilabel
:
bool
=
False
output_audio
:
bool
=
False
audio_feature
:
str
=
''
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment