ModelZoo / ResNet50_tensorflow · Commit c127d527
"vscode:/vscode.git/clone" did not exist on "5538e05cb14641a41e36ebf96cd3611e37229f3a"
Unverified commit c127d527, authored Feb 04, 2022 by Srihari Humbarwadi, committed by GitHub on Feb 04, 2022.
Merge branch 'panoptic-segmentation' into panoptic-deeplab-modeling
Parents: 78657911, 457bcb85
Changes: 71 files in this commit; this page shows 20 changed files with 271 additions and 50 deletions (+271, -50).
official/projects/detr/tasks/detection_test.py  +116 -0
official/projects/detr/train.py  +70 -0
official/projects/edgetpu/nlp/serving/export_tflite_squad.py  +2 -1
official/projects/movinet/configs/yaml/movinet_a0_k600_8x8.yaml  +1 -0
official/projects/movinet/configs/yaml/movinet_a0_k600_cpu_local.yaml  +1 -0
official/projects/movinet/configs/yaml/movinet_a0_stream_k600_8x8.yaml  +1 -0
official/projects/movinet/configs/yaml/movinet_a1_k600_8x8.yaml  +1 -0
official/projects/movinet/configs/yaml/movinet_a1_stream_k600_8x8.yaml  +1 -0
official/projects/movinet/configs/yaml/movinet_a2_k600_8x8.yaml  +1 -0
official/projects/movinet/configs/yaml/movinet_a2_stream_k600_8x8.yaml  +1 -0
official/projects/movinet/configs/yaml/movinet_a3_k600_8x8.yaml  +1 -0
official/projects/movinet/configs/yaml/movinet_a3_stream_k600_8x8.yaml  +1 -0
official/projects/movinet/configs/yaml/movinet_a4_k600_8x8.yaml  +1 -0
official/projects/movinet/configs/yaml/movinet_a4_stream_k600_8x8.yaml  +1 -0
official/projects/movinet/configs/yaml/movinet_a5_k600_8x8.yaml  +1 -0
official/projects/movinet/configs/yaml/movinet_a5_stream_k600_8x8.yaml  +1 -0
official/projects/movinet/configs/yaml/movinet_t0_k600_8x8.yaml  +1 -0
official/projects/movinet/configs/yaml/movinet_t0_stream_k600_8x8.yaml  +1 -0
official/projects/movinet/modeling/movinet.py  +7 -6
official/projects/movinet/modeling/movinet_layers.py  +61 -43
official/projects/detr/tasks/detection_test.py (new file, 0 → 100644)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for detection."""
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

from official.projects.detr import optimization
from official.projects.detr.configs import detr as detr_cfg
from official.projects.detr.dataloaders import coco
from official.projects.detr.tasks import detection

_NUM_EXAMPLES = 10


def _gen_fn():
  h = np.random.randint(0, 300)
  w = np.random.randint(0, 300)
  num_boxes = np.random.randint(0, 50)
  return {
      'image': np.ones(shape=(h, w, 3), dtype=np.uint8),
      'image/id': np.random.randint(0, 100),
      'image/filename': 'test',
      'objects': {
          'is_crowd': np.ones(shape=(num_boxes), dtype=np.bool),
          'bbox': np.ones(shape=(num_boxes, 4), dtype=np.float32),
          'label': np.ones(shape=(num_boxes), dtype=np.int64),
          'id': np.ones(shape=(num_boxes), dtype=np.int64),
          'area': np.ones(shape=(num_boxes), dtype=np.int64),
      }
  }


def _as_dataset(self, *args, **kwargs):
  del args
  del kwargs
  return tf.data.Dataset.from_generator(
      lambda: (_gen_fn() for i in range(_NUM_EXAMPLES)),
      output_types=self.info.features.dtype,
      output_shapes=self.info.features.shape,
  )


class DetectionTest(tf.test.TestCase):

  def test_train_step(self):
    config = detr_cfg.DetectionConfig(
        num_encoder_layers=1,
        num_decoder_layers=1,
        train_data=coco.COCODataConfig(
            tfds_name='coco/2017',
            tfds_split='validation',
            is_training=True,
            global_batch_size=2,
        ))

    with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
      task = detection.DectectionTask(config)
      model = task.build_model()
      dataset = task.build_inputs(config.train_data)
      iterator = iter(dataset)
      opt_cfg = optimization.OptimizationConfig({
          'optimizer': {
              'type': 'detr_adamw',
              'detr_adamw': {
                  'weight_decay_rate': 1e-4,
                  'global_clipnorm': 0.1,
              }
          },
          'learning_rate': {
              'type': 'stepwise',
              'stepwise': {
                  'boundaries': [120000],
                  'values': [0.0001, 1.0e-05]
              }
          },
      })
      optimizer = detection.DectectionTask.create_optimizer(opt_cfg)
      task.train_step(next(iterator), model, optimizer)

  def test_validation_step(self):
    config = detr_cfg.DetectionConfig(
        num_encoder_layers=1,
        num_decoder_layers=1,
        validation_data=coco.COCODataConfig(
            tfds_name='coco/2017',
            tfds_split='validation',
            is_training=False,
            global_batch_size=2,
        ))

    with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
      task = detection.DectectionTask(config)
      model = task.build_model()
      metrics = task.build_metrics(training=False)
      dataset = task.build_inputs(config.validation_data)
      iterator = iter(dataset)
      logs = task.validation_step(next(iterator), model, metrics)
      state = task.aggregate_logs(step_outputs=logs)
      task.reduce_aggregated_logs(state)


if __name__ == '__main__':
  tf.test.main()
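The test above never downloads COCO: tfds.testing.mock_data intercepts dataset loading and substitutes generated examples shaped like the real features. A minimal standalone sketch of that pattern, assuming the tensorflow_datasets mocking API; the dataset name and printed fields are illustrative and not taken from this commit:

import tensorflow_datasets as tfds

# Inside mock_data, tfds.load returns randomly generated examples that match
# the dataset's declared feature spec, so no real download is required.
with tfds.testing.mock_data(num_examples=4):
  ds = tfds.load('coco/2017', split='validation')
  for example in ds.take(1):
    print(example['image'].shape, example['objects']['label'])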
official/projects/detr/train.py (new file, 0 → 100644)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorFlow Model Garden Vision training driver."""
from absl import app
from absl import flags
import gin

from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
# pylint: disable=unused-import
from official.projects.detr.configs import detr
from official.projects.detr.tasks import detection
# pylint: enable=unused-import

FLAGS = flags.FLAGS


def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  params = train_utils.parse_configuration(FLAGS)
  model_dir = FLAGS.model_dir
  if 'train' in FLAGS.mode:
    # Pure eval modes do not output yaml files. Otherwise continuous eval job
    # may race against the train job for writing the same file.
    train_utils.serialize_config(params, model_dir)

  # Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can have significant impact on model speeds by utilizing float16 in case of
  # GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
  # dtype is float16
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu)
  with distribution_strategy.scope():
    task = task_factory.get_task(params.task, logging_dir=model_dir)

  train_lib.run_experiment(
      distribution_strategy=distribution_strategy,
      task=task,
      mode=FLAGS.mode,
      params=params,
      model_dir=model_dir)

  train_utils.save_gin_config(FLAGS.mode, model_dir)


if __name__ == '__main__':
  tfm_flags.define_flags()
  flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
  app.run(main)
official/projects/edgetpu/nlp/serving/export_tflite_squad.py

@@ -135,7 +135,8 @@ def main(argv: Sequence[str]) -> None:
   checkpoint = tf.train.Checkpoint(**checkpoint_dict)
   checkpoint.restore(FLAGS.model_checkpoint).assert_existing_objects_matched()

-  model_for_serving = build_model_for_serving(model)
+  model_for_serving = build_model_for_serving(model, FLAGS.sequence_length,
+                                              FLAGS.batch_size)
   model_for_serving.summary()

   # TODO(b/194449109): Need to save the model to file and then convert tflite
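For context on the change above (a sketch, not code from this commit): export paths pass an explicit sequence length and batch size because TFLite conversion works with fully defined input shapes. The toy Keras model below stands in for the real serving model built by this script; the vocabulary size and shapes are illustrative assumptions.

import tensorflow as tf

batch_size, sequence_length = 1, 384  # illustrative values

model = tf.keras.Sequential([
    tf.keras.layers.InputLayer(
        input_shape=(sequence_length,), batch_size=batch_size, dtype=tf.int32),
    tf.keras.layers.Embedding(input_dim=30522, output_dim=16),
    tf.keras.layers.GlobalAveragePooling1D(),
    tf.keras.layers.Dense(2),
])

# Conversion produces a model whose inputs keep the fixed shape
# (batch_size, sequence_length) declared above.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()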
official/projects/movinet/configs/yaml/movinet_a0_k600_8x8.yaml

@@ -18,6 +18,7 @@ task:
     norm_activation:
       use_sync_bn: true
     dropout_rate: 0.2
+    activation: 'swish'
   train_data:
     name: kinetics600
     variant_name: rgb

official/projects/movinet/configs/yaml/movinet_a0_k600_cpu_local.yaml

@@ -12,6 +12,7 @@ task:
     norm_activation:
       use_sync_bn: false
     dropout_rate: 0.5
+    activation: 'swish'
   train_data:
     name: kinetics600
     variant_name: rgb

official/projects/movinet/configs/yaml/movinet_a0_stream_k600_8x8.yaml

@@ -24,6 +24,7 @@ task:
     norm_activation:
       use_sync_bn: true
     dropout_rate: 0.2
+    activation: 'hard_swish'
   train_data:
     name: kinetics600
     variant_name: rgb

official/projects/movinet/configs/yaml/movinet_a1_k600_8x8.yaml

@@ -18,6 +18,7 @@ task:
     norm_activation:
       use_sync_bn: true
     dropout_rate: 0.5
+    activation: 'swish'
   train_data:
     name: kinetics600
     variant_name: rgb

official/projects/movinet/configs/yaml/movinet_a1_stream_k600_8x8.yaml

@@ -24,6 +24,7 @@ task:
     norm_activation:
       use_sync_bn: true
     dropout_rate: 0.2
+    activation: 'hard_swish'
   train_data:
     name: kinetics600
     variant_name: rgb

official/projects/movinet/configs/yaml/movinet_a2_k600_8x8.yaml

@@ -18,6 +18,7 @@ task:
     norm_activation:
       use_sync_bn: true
     dropout_rate: 0.5
+    activation: 'swish'
   train_data:
     name: kinetics600
     variant_name: rgb

official/projects/movinet/configs/yaml/movinet_a2_stream_k600_8x8.yaml

@@ -24,6 +24,7 @@ task:
     norm_activation:
       use_sync_bn: true
     dropout_rate: 0.5
+    activation: 'hard_swish'
   train_data:
     name: kinetics600
     variant_name: rgb

official/projects/movinet/configs/yaml/movinet_a3_k600_8x8.yaml

@@ -18,6 +18,7 @@ task:
     norm_activation:
       use_sync_bn: true
     dropout_rate: 0.5
+    activation: 'swish'
   train_data:
     name: kinetics600
     variant_name: rgb

official/projects/movinet/configs/yaml/movinet_a3_stream_k600_8x8.yaml

@@ -25,6 +25,7 @@ task:
     norm_activation:
       use_sync_bn: true
     dropout_rate: 0.5
+    activation: 'hard_swish'
   train_data:
     name: kinetics600
     variant_name: rgb

official/projects/movinet/configs/yaml/movinet_a4_k600_8x8.yaml

@@ -18,6 +18,7 @@ task:
     norm_activation:
       use_sync_bn: true
     dropout_rate: 0.5
+    activation: 'swish'
   train_data:
     name: kinetics600
     variant_name: rgb

official/projects/movinet/configs/yaml/movinet_a4_stream_k600_8x8.yaml

@@ -25,6 +25,7 @@ task:
     norm_activation:
       use_sync_bn: true
     dropout_rate: 0.5
+    activation: 'hard_swish'
   train_data:
     name: kinetics600
     variant_name: rgb

official/projects/movinet/configs/yaml/movinet_a5_k600_8x8.yaml

@@ -18,6 +18,7 @@ task:
     norm_activation:
       use_sync_bn: true
     dropout_rate: 0.5
+    activation: 'swish'
   train_data:
     name: kinetics600
     variant_name: rgb

official/projects/movinet/configs/yaml/movinet_a5_stream_k600_8x8.yaml

@@ -25,6 +25,7 @@ task:
     norm_activation:
       use_sync_bn: true
     dropout_rate: 0.5
+    activation: 'hard_swish'
   train_data:
     name: kinetics600
     variant_name: rgb

official/projects/movinet/configs/yaml/movinet_t0_k600_8x8.yaml

@@ -18,6 +18,7 @@ task:
     norm_activation:
       use_sync_bn: true
     dropout_rate: 0.2
+    activation: 'swish'
   train_data:
     name: kinetics600
     variant_name: rgb

official/projects/movinet/configs/yaml/movinet_t0_stream_k600_8x8.yaml

@@ -24,6 +24,7 @@ task:
     norm_activation:
       use_sync_bn: true
     dropout_rate: 0.2
+    activation: 'hard_swish'
   train_data:
     name: kinetics600
     variant_name: rgb
official/projects/movinet/modeling/movinet.py

@@ -338,7 +338,7 @@ class Movinet(tf.keras.Model):
         3x3 followed by 5x1 conv). '3d_2plus1d' uses (2+1)D convolution with
         Conv3D and no 2D reshaping (e.g., a 5x3x3 kernel becomes 1x3x3 followed
         by 5x1x1 conv).
-      se_type: '3d', '2d', or '2plus3d'. '3d' uses the default 3D
+      se_type: '3d', '2d', '2plus3d' or 'none'. '3d' uses the default 3D
         spatiotemporal global average pooling for squeeze excitation. '2d'
         uses 2D spatial global average pooling on each frame. '2plus3d'
         concatenates both 3D and 2D global average pooling.

@@ -369,7 +369,7 @@ class Movinet(tf.keras.Model):
     if conv_type not in ('3d', '2plus1d', '3d_2plus1d'):
       raise ValueError('Unknown conv type: {}'.format(conv_type))
-    if se_type not in ('3d', '2d', '2plus3d'):
+    if se_type not in ('3d', '2d', '2plus3d', 'none'):
       raise ValueError('Unknown squeeze excitation type: {}'.format(se_type))

     self._model_id = model_id

@@ -602,10 +602,11 @@ class Movinet(tf.keras.Model):
           expand_filters,
       )
-      states[f'{prefix}_pool_buffer'] = (
-          input_shape[0], 1, 1, 1, expand_filters,
-      )
-      states[f'{prefix}_pool_frame_count'] = (1,)
+      if '3d' in self._se_type:
+        states[f'{prefix}_pool_buffer'] = (
+            input_shape[0], 1, 1, 1, expand_filters,
+        )
+        states[f'{prefix}_pool_frame_count'] = (1,)

       if use_positional_encoding:
         name = f'{prefix}_pos_enc_frame_count'
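The two movinet.py hunks above make 'none' an accepted se_type, so a backbone can be constructed without squeeze-excitation buffers in its per-block states. A minimal sketch of what this enables; the model_id value is an illustrative default, and only the se_type argument comes from this diff:

from official.projects.movinet.modeling import movinet

# With se_type='none', no squeeze-excitation pooling state is created
# (see the `'3d' in self._se_type` guard added above).
backbone = movinet.Movinet(model_id='a0', se_type='none')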
official/projects/movinet/modeling/movinet_layers.py

@@ -93,10 +93,9 @@ class MobileConv2D(tf.keras.layers.Layer):
                data_format: Optional[str] = None,
                dilation_rate: Union[int, Sequence[int]] = (1, 1),
                groups: int = 1,
-               activation: Optional[nn_layers.Activation] = None,
                use_bias: bool = True,
-               kernel_initializer: tf.keras.initializers.Initializer = 'glorot_uniform',
-               bias_initializer: tf.keras.initializers.Initializer = 'zeros',
+               kernel_initializer: str = 'glorot_uniform',
+               bias_initializer: str = 'zeros',
                kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
                bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
                activity_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
@@ -105,6 +104,8 @@ class MobileConv2D(tf.keras.layers.Layer):
                use_depthwise: bool = False,
                use_temporal: bool = False,
                use_buffered_input: bool = False,  # pytype: disable=annotation-type-mismatch  # typed-keras
+               batch_norm_op: Optional[Any] = None,
+               activation_op: Optional[Any] = None,
                **kwargs):
     # pylint: disable=g-doc-args
     """Initializes mobile conv2d.
@@ -117,6 +118,10 @@ class MobileConv2D(tf.keras.layers.Layer):
       use_buffered_input: if True, the input is expected to be padded
         beforehand. In effect, calling this layer will use 'valid' padding on
         the temporal dimension to simulate 'causal' padding.
+      batch_norm_op: A callable object of batch norm layer. If None, no batch
+        norm will be applied after the convolution.
+      activation_op: A callable object of activation layer. If None, no
+        activation will be applied after the convolution.
       **kwargs: keyword arguments to be passed to this layer.

     Returns:
@@ -130,7 +135,6 @@ class MobileConv2D(tf.keras.layers.Layer):
     self._data_format = data_format
     self._dilation_rate = dilation_rate
     self._groups = groups
-    self._activation = activation
     self._use_bias = use_bias
     self._kernel_initializer = kernel_initializer
     self._bias_initializer = bias_initializer

@@ -142,6 +146,8 @@ class MobileConv2D(tf.keras.layers.Layer):
     self._use_depthwise = use_depthwise
     self._use_temporal = use_temporal
     self._use_buffered_input = use_buffered_input
+    self._batch_norm_op = batch_norm_op
+    self._activation_op = activation_op

     kernel_size = normalize_tuple(kernel_size, 2, 'kernel_size')

@@ -156,7 +162,6 @@ class MobileConv2D(tf.keras.layers.Layer):
           depth_multiplier=1,
           data_format=data_format,
           dilation_rate=dilation_rate,
-          activation=activation,
           use_bias=use_bias,
           depthwise_initializer=kernel_initializer,
           bias_initializer=bias_initializer,

@@ -175,7 +180,6 @@ class MobileConv2D(tf.keras.layers.Layer):
           data_format=data_format,
           dilation_rate=dilation_rate,
           groups=groups,
-          activation=activation,
           use_bias=use_bias,
           kernel_initializer=kernel_initializer,
           bias_initializer=bias_initializer,

@@ -196,7 +200,6 @@ class MobileConv2D(tf.keras.layers.Layer):
         'data_format': self._data_format,
         'dilation_rate': self._dilation_rate,
         'groups': self._groups,
-        'activation': self._activation,
         'use_bias': self._use_bias,
         'kernel_initializer': self._kernel_initializer,
         'bias_initializer': self._bias_initializer,

@@ -229,6 +232,10 @@ class MobileConv2D(tf.keras.layers.Layer):
       x = tf.reshape(inputs, input_shape)
     x = self._conv(x)

+    if self._batch_norm_op is not None:
+      x = self._batch_norm_op(x)
+    if self._activation_op is not None:
+      x = self._activation_op(x)

     if self._use_temporal:
       output_shape = [
@@ -357,8 +364,20 @@ class ConvBlock(tf.keras.layers.Layer):
     padding = 'causal' if self._causal else 'same'
     self._groups = input_shape[-1] if self._depthwise else 1

+    self._conv_temporal = None
+    self._batch_norm = None
+    self._batch_norm_temporal = None
+    if self._use_batch_norm:
+      self._batch_norm = self._batch_norm_layer(
+          momentum=self._batch_norm_momentum,
+          epsilon=self._batch_norm_epsilon,
+          name='bn')
+      if self._conv_type != '3d' and self._kernel_size[0] > 1:
+        self._batch_norm_temporal = self._batch_norm_layer(
+            momentum=self._batch_norm_momentum,
+            epsilon=self._batch_norm_epsilon,
+            name='bn_temporal')
-    self._conv_temporal = None
     if self._conv_type == '3d_2plus1d' and self._kernel_size[0] > 1:
       self._conv = nn_layers.Conv3D(
           self._filters,

@@ -394,6 +413,8 @@ class ConvBlock(tf.keras.layers.Layer):
           kernel_initializer=self._kernel_initializer,
           kernel_regularizer=self._kernel_regularizer,
           use_buffered_input=False,
+          batch_norm_op=self._batch_norm,
+          activation_op=self._activation_layer,
           name='conv2d')
       if self._kernel_size[0] > 1:
         self._conv_temporal = MobileConv2D(

@@ -408,6 +429,8 @@ class ConvBlock(tf.keras.layers.Layer):
           kernel_initializer=self._kernel_initializer,
           kernel_regularizer=self._kernel_regularizer,
           use_buffered_input=self._use_buffered_input,
+          batch_norm_op=self._batch_norm_temporal,
+          activation_op=self._activation_layer,
           name='conv2d_temporal')
     else:
       self._conv = nn_layers.Conv3D(
@@ -422,37 +445,26 @@ class ConvBlock(tf.keras.layers.Layer):
           use_buffered_input=self._use_buffered_input,
           name='conv3d')

-    self._batch_norm = None
-    self._batch_norm_temporal = None
-    if self._use_batch_norm:
-      self._batch_norm = self._batch_norm_layer(
-          momentum=self._batch_norm_momentum,
-          epsilon=self._batch_norm_epsilon,
-          name='bn')
-      if self._conv_type != '3d' and self._conv_temporal is not None:
-        self._batch_norm_temporal = self._batch_norm_layer(
-            momentum=self._batch_norm_momentum,
-            epsilon=self._batch_norm_epsilon,
-            name='bn_temporal')
-
     super(ConvBlock, self).build(input_shape)

   def call(self, inputs):
     """Calls the layer with the given inputs."""
     x = inputs
+    # bn_op and activation_op are folded into the '2plus1d' conv layer so that
+    # we do not explicitly call them here.
+    # TODO(lzyuan): clean the conv layers api once the models are re-trained.
     x = self._conv(x)
-    if self._batch_norm is not None:
+    if self._batch_norm is not None and self._conv_type != '2plus1d':
       x = self._batch_norm(x)
-    if self._activation_layer is not None:
+    if self._activation_layer is not None and self._conv_type != '2plus1d':
       x = self._activation_layer(x)

     if self._conv_temporal is not None:
       x = self._conv_temporal(x)
-      if self._batch_norm_temporal is not None:
+      if self._batch_norm_temporal is not None and self._conv_type != '2plus1d':
        x = self._batch_norm_temporal(x)
-      if self._activation_layer is not None:
+      if self._activation_layer is not None and self._conv_type != '2plus1d':
        x = self._activation_layer(x)

     return x
@@ -640,10 +652,13 @@ class StreamConvBlock(ConvBlock):
     if self._conv_temporal is None and self._stream_buffer is not None:
       x, states = self._stream_buffer(x, states=states)

+    # bn_op and activation_op are folded into the '2plus1d' conv layer so that
+    # we do not explicitly call them here.
+    # TODO(lzyuan): clean the conv layers api once the models are re-trained.
     x = self._conv(x)
-    if self._batch_norm is not None:
+    if self._batch_norm is not None and self._conv_type != '2plus1d':
       x = self._batch_norm(x)
-    if self._activation_layer is not None:
+    if self._activation_layer is not None and self._conv_type != '2plus1d':
       x = self._activation_layer(x)

     if self._conv_temporal is not None:

@@ -653,9 +668,9 @@ class StreamConvBlock(ConvBlock):
         x, states = self._stream_buffer(x, states=states)

       x = self._conv_temporal(x)
-      if self._batch_norm_temporal is not None:
+      if self._batch_norm_temporal is not None and self._conv_type != '2plus1d':
        x = self._batch_norm_temporal(x)
-      if self._activation_layer is not None:
+      if self._activation_layer is not None and self._conv_type != '2plus1d':
        x = self._activation_layer(x)

     return x, states
@@ -885,7 +900,8 @@ class MobileBottleneck(tf.keras.layers.Layer):
     x = self._expansion_layer(inputs)
     x, states = self._feature_layer(x, states=states)
-    x, states = self._attention_layer(x, states=states)
+    if self._attention_layer is not None:
+      x, states = self._attention_layer(x, states=states)
     x = self._projection_layer(x)

     # Add identity so that the ops are ordered as written. This is useful for,
@@ -1136,18 +1152,20 @@ class MovinetBlock(tf.keras.layers.Layer):
         batch_norm_momentum=self._batch_norm_momentum,
         batch_norm_epsilon=self._batch_norm_epsilon,
         name='projection')
-    self._attention = StreamSqueezeExcitation(
-        se_hidden_filters,
-        se_type=se_type,
-        activation=activation,
-        gating_activation=gating_activation,
-        causal=self._causal,
-        conv_type=conv_type,
-        use_positional_encoding=use_positional_encoding,
-        kernel_initializer=kernel_initializer,
-        kernel_regularizer=kernel_regularizer,
-        state_prefix=state_prefix,
-        name='se')
+    self._attention = None
+    if se_type != 'none':
+      self._attention = StreamSqueezeExcitation(
+          se_hidden_filters,
+          se_type=se_type,
+          activation=activation,
+          gating_activation=gating_activation,
+          causal=self._causal,
+          conv_type=conv_type,
+          use_positional_encoding=use_positional_encoding,
+          kernel_initializer=kernel_initializer,
+          kernel_regularizer=kernel_regularizer,
+          state_prefix=state_prefix,
+          name='se')

   def get_config(self):
     """Returns a dictionary containing the config used for initialization."""
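Taken together, the movinet_layers.py hunks let ConvBlock fold its batch norm and activation into the 2D convolution path via MobileConv2D's new batch_norm_op and activation_op arguments. A rough sketch of those two arguments in isolation; the filter count, kernel size, and input shape are illustrative, and only batch_norm_op and activation_op come from this diff:

import tensorflow as tf
from official.projects.movinet.modeling import movinet_layers

bn = tf.keras.layers.BatchNormalization()

# batch_norm_op and activation_op run immediately after the reshaped 2D
# convolution inside MobileConv2D, mirroring the order ConvBlock previously
# enforced by calling batch norm and activation itself.
conv = movinet_layers.MobileConv2D(
    filters=8,
    kernel_size=(3, 3),
    padding='same',
    batch_norm_op=bn,
    activation_op=tf.nn.swish)

x = tf.ones([1, 4, 16, 16, 3])  # (batch, frames, height, width, channels)
y = conv(x)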