Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
e9355843
Commit
e9355843
authored
Oct 08, 2021
by
A. Unique TensorFlower
Committed by
saberkun
Oct 08, 2021
Browse files
Internal change
PiperOrigin-RevId: 401839863
parent
8bfa4d03
Changes
42
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
325 additions
and
115 deletions
+325
-115
official/projects/volumetric_models/modeling/factory.py
official/projects/volumetric_models/modeling/factory.py
+2
-2
official/projects/volumetric_models/modeling/factory_test.py
official/projects/volumetric_models/modeling/factory_test.py
+4
-4
official/projects/volumetric_models/modeling/heads/segmentation_heads_3d.py
...volumetric_models/modeling/heads/segmentation_heads_3d.py
+0
-0
official/projects/volumetric_models/modeling/heads/segmentation_heads_3d_test.py
...etric_models/modeling/heads/segmentation_heads_3d_test.py
+1
-1
official/projects/volumetric_models/modeling/nn_blocks_3d.py
official/projects/volumetric_models/modeling/nn_blocks_3d.py
+0
-0
official/projects/volumetric_models/modeling/nn_blocks_3d_test.py
.../projects/volumetric_models/modeling/nn_blocks_3d_test.py
+1
-1
official/projects/volumetric_models/modeling/segmentation_model_test.py
...cts/volumetric_models/modeling/segmentation_model_test.py
+3
-4
official/projects/volumetric_models/registry_imports.py
official/projects/volumetric_models/registry_imports.py
+4
-4
official/projects/volumetric_models/serving/export_saved_model.py
.../projects/volumetric_models/serving/export_saved_model.py
+1
-1
official/projects/volumetric_models/serving/semantic_segmentation_3d.py
...cts/volumetric_models/serving/semantic_segmentation_3d.py
+3
-3
official/projects/volumetric_models/serving/semantic_segmentation_3d_test.py
...olumetric_models/serving/semantic_segmentation_3d_test.py
+4
-4
official/projects/volumetric_models/tasks/semantic_segmentation_3d.py
...jects/volumetric_models/tasks/semantic_segmentation_3d.py
+5
-5
official/projects/volumetric_models/tasks/semantic_segmentation_3d_test.py
.../volumetric_models/tasks/semantic_segmentation_3d_test.py
+4
-4
official/projects/volumetric_models/train.py
official/projects/volumetric_models/train.py
+1
-2
official/projects/volumetric_models/train_test.py
official/projects/volumetric_models/train_test.py
+1
-2
official/vision/beta/projects/yt8m/configs/yt8m.py
official/vision/beta/projects/yt8m/configs/yt8m.py
+52
-15
official/vision/beta/projects/yt8m/modeling/yt8m_agg_models.py
...ial/vision/beta/projects/yt8m/modeling/yt8m_agg_models.py
+27
-9
official/vision/beta/projects/yt8m/modeling/yt8m_model.py
official/vision/beta/projects/yt8m/modeling/yt8m_model.py
+96
-47
official/vision/beta/projects/yt8m/modeling/yt8m_model_test.py
...ial/vision/beta/projects/yt8m/modeling/yt8m_model_test.py
+4
-4
official/vision/beta/projects/yt8m/modeling/yt8m_model_utils.py
...al/vision/beta/projects/yt8m/modeling/yt8m_model_utils.py
+112
-3
No files found.
official/
vision/beta/
projects/volumetric_models/modeling/factory.py
→
official/projects/volumetric_models/modeling/factory.py
View file @
e9355843
...
@@ -19,10 +19,10 @@
...
@@ -19,10 +19,10 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.modeling
import
hyperparams
from
official.modeling
import
hyperparams
from
official.projects.volumetric_models.modeling.decoders
import
factory
as
decoder_factory
from
official.projects.volumetric_models.modeling.heads
import
segmentation_heads_3d
from
official.vision.beta.modeling
import
segmentation_model
from
official.vision.beta.modeling
import
segmentation_model
from
official.vision.beta.modeling.backbones
import
factory
as
backbone_factory
from
official.vision.beta.modeling.backbones
import
factory
as
backbone_factory
from
official.vision.beta.projects.volumetric_models.modeling.decoders
import
factory
as
decoder_factory
from
official.vision.beta.projects.volumetric_models.modeling.heads
import
segmentation_heads_3d
def
build_segmentation_model_3d
(
def
build_segmentation_model_3d
(
...
...
official/
vision/beta/
projects/volumetric_models/modeling/factory_test.py
→
official/projects/volumetric_models/modeling/factory_test.py
View file @
e9355843
...
@@ -18,10 +18,10 @@ from absl.testing import parameterized
...
@@ -18,10 +18,10 @@ from absl.testing import parameterized
import
tensorflow
as
tf
import
tensorflow
as
tf
# pylint: disable=unused-import
# pylint: disable=unused-import
from
official.
vision.beta.
projects.volumetric_models.configs
import
semantic_segmentation_3d
as
exp_cfg
from
official.projects.volumetric_models.configs
import
semantic_segmentation_3d
as
exp_cfg
from
official.
vision.beta.
projects.volumetric_models.modeling
import
backbones
from
official.projects.volumetric_models.modeling
import
backbones
from
official.
vision.beta.
projects.volumetric_models.modeling
import
decoders
from
official.projects.volumetric_models.modeling
import
decoders
from
official.
vision.beta.
projects.volumetric_models.modeling
import
factory
from
official.projects.volumetric_models.modeling
import
factory
class
SegmentationModelBuilderTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
class
SegmentationModelBuilderTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
...
...
official/
vision/beta/
projects/volumetric_models/modeling/heads/segmentation_heads_3d.py
→
official/projects/volumetric_models/modeling/heads/segmentation_heads_3d.py
View file @
e9355843
File moved
official/
vision/beta/
projects/volumetric_models/modeling/heads/segmentation_heads_3d_test.py
→
official/projects/volumetric_models/modeling/heads/segmentation_heads_3d_test.py
View file @
e9355843
...
@@ -19,7 +19,7 @@ from absl.testing import parameterized
...
@@ -19,7 +19,7 @@ from absl.testing import parameterized
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.
vision.beta.
projects.volumetric_models.modeling.heads
import
segmentation_heads_3d
from
official.projects.volumetric_models.modeling.heads
import
segmentation_heads_3d
class
SegmentationHead3DTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
class
SegmentationHead3DTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
...
...
official/
vision/beta/
projects/volumetric_models/modeling/nn_blocks_3d.py
→
official/projects/volumetric_models/modeling/nn_blocks_3d.py
View file @
e9355843
File moved
official/
vision/beta/
projects/volumetric_models/modeling/nn_blocks_3d_test.py
→
official/projects/volumetric_models/modeling/nn_blocks_3d_test.py
View file @
e9355843
...
@@ -19,7 +19,7 @@
...
@@ -19,7 +19,7 @@
from
absl.testing
import
parameterized
from
absl.testing
import
parameterized
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.
vision.beta.
projects.volumetric_models.modeling
import
nn_blocks_3d
from
official.projects.volumetric_models.modeling
import
nn_blocks_3d
class
NNBlocks3DTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
class
NNBlocks3DTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
...
...
official/
vision/beta/
projects/volumetric_models/modeling/segmentation_model_test.py
→
official/projects/volumetric_models/modeling/segmentation_model_test.py
View file @
e9355843
...
@@ -18,11 +18,10 @@
...
@@ -18,11 +18,10 @@
from
absl.testing
import
parameterized
from
absl.testing
import
parameterized
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.projects.volumetric_models.modeling
import
backbones
from
official.projects.volumetric_models.modeling
import
decoders
from
official.projects.volumetric_models.modeling.heads
import
segmentation_heads_3d
from
official.vision.beta.modeling
import
segmentation_model
from
official.vision.beta.modeling
import
segmentation_model
from
official.vision.beta.projects.volumetric_models.modeling
import
backbones
from
official.vision.beta.projects.volumetric_models.modeling
import
decoders
from
official.vision.beta.projects.volumetric_models.modeling.heads
import
segmentation_heads_3d
class
SegmentationNetworkUNet3DTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
class
SegmentationNetworkUNet3DTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
...
...
official/
vision/beta/
projects/volumetric_models/registry_imports.py
→
official/projects/volumetric_models/registry_imports.py
View file @
e9355843
...
@@ -16,7 +16,7 @@
...
@@ -16,7 +16,7 @@
# pylint: disable=unused-import
# pylint: disable=unused-import
from
official.common
import
registry_imports
from
official.common
import
registry_imports
from
official.
vision.beta.
projects.volumetric_models.configs
import
semantic_segmentation_3d
as
semantic_segmentation_3d_cfg
from
official.projects.volumetric_models.configs
import
semantic_segmentation_3d
as
semantic_segmentation_3d_cfg
from
official.
vision.beta.
projects.volumetric_models.modeling
import
backbones
from
official.projects.volumetric_models.modeling
import
backbones
from
official.
vision.beta.
projects.volumetric_models.modeling
import
decoders
from
official.projects.volumetric_models.modeling
import
decoders
from
official.
vision.beta.
projects.volumetric_models.tasks
import
semantic_segmentation_3d
from
official.projects.volumetric_models.tasks
import
semantic_segmentation_3d
official/
vision/beta/
projects/volumetric_models/serving/export_saved_model.py
→
official/projects/volumetric_models/serving/export_saved_model.py
View file @
e9355843
...
@@ -42,7 +42,7 @@ from absl import flags
...
@@ -42,7 +42,7 @@ from absl import flags
from
official.common
import
registry_imports
# pylint: disable=unused-import
from
official.common
import
registry_imports
# pylint: disable=unused-import
from
official.core
import
exp_factory
from
official.core
import
exp_factory
from
official.modeling
import
hyperparams
from
official.modeling
import
hyperparams
from
official.
vision.beta.
projects.volumetric_models.serving
import
semantic_segmentation_3d
from
official.projects.volumetric_models.serving
import
semantic_segmentation_3d
from
official.vision.beta.serving
import
export_saved_model_lib
from
official.vision.beta.serving
import
export_saved_model_lib
FLAGS
=
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
...
...
official/
vision/beta/
projects/volumetric_models/serving/semantic_segmentation_3d.py
→
official/projects/volumetric_models/serving/semantic_segmentation_3d.py
View file @
e9355843
...
@@ -19,9 +19,9 @@ from typing import Mapping
...
@@ -19,9 +19,9 @@ from typing import Mapping
import
tensorflow
as
tf
import
tensorflow
as
tf
# pylint: disable=unused-import
# pylint: disable=unused-import
from
official.
vision.beta.
projects.volumetric_models.modeling
import
backbones
from
official.projects.volumetric_models.modeling
import
backbones
from
official.
vision.beta.
projects.volumetric_models.modeling
import
decoders
from
official.projects.volumetric_models.modeling
import
decoders
from
official.
vision.beta.
projects.volumetric_models.modeling
import
factory
from
official.projects.volumetric_models.modeling
import
factory
from
official.vision.beta.serving
import
export_base
from
official.vision.beta.serving
import
export_base
...
...
official/
vision/beta/
projects/volumetric_models/serving/semantic_segmentation_3d_test.py
→
official/projects/volumetric_models/serving/semantic_segmentation_3d_test.py
View file @
e9355843
...
@@ -22,10 +22,10 @@ import tensorflow as tf
...
@@ -22,10 +22,10 @@ import tensorflow as tf
# pylint: disable=unused-import
# pylint: disable=unused-import
from
official.core
import
exp_factory
from
official.core
import
exp_factory
from
official.
vision.beta.
projects.volumetric_models.configs
import
semantic_segmentation_3d
as
exp_cfg
from
official.projects.volumetric_models.configs
import
semantic_segmentation_3d
as
exp_cfg
from
official.
vision.beta.
projects.volumetric_models.modeling
import
backbones
from
official.projects.volumetric_models.modeling
import
backbones
from
official.
vision.beta.
projects.volumetric_models.modeling
import
decoders
from
official.projects.volumetric_models.modeling
import
decoders
from
official.
vision.beta.
projects.volumetric_models.serving
import
semantic_segmentation_3d
from
official.projects.volumetric_models.serving
import
semantic_segmentation_3d
class
SemanticSegmentationExportTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
class
SemanticSegmentationExportTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
...
...
official/
vision/beta/
projects/volumetric_models/tasks/semantic_segmentation_3d.py
→
official/projects/volumetric_models/tasks/semantic_segmentation_3d.py
View file @
e9355843
...
@@ -23,11 +23,11 @@ from official.common import dataset_fn
...
@@ -23,11 +23,11 @@ from official.common import dataset_fn
from
official.core
import
base_task
from
official.core
import
base_task
from
official.core
import
input_reader
from
official.core
import
input_reader
from
official.core
import
task_factory
from
official.core
import
task_factory
from
official.
vision.beta.
projects.volumetric_models.configs
import
semantic_segmentation_3d
as
exp_cfg
from
official.projects.volumetric_models.configs
import
semantic_segmentation_3d
as
exp_cfg
from
official.
vision.beta.
projects.volumetric_models.dataloaders
import
segmentation_input_3d
from
official.projects.volumetric_models.dataloaders
import
segmentation_input_3d
from
official.
vision.beta.
projects.volumetric_models.evaluation
import
segmentation_metrics
from
official.projects.volumetric_models.evaluation
import
segmentation_metrics
from
official.
vision.beta.
projects.volumetric_models.losses
import
segmentation_losses
from
official.projects.volumetric_models.losses
import
segmentation_losses
from
official.
vision.beta.
projects.volumetric_models.modeling
import
factory
from
official.projects.volumetric_models.modeling
import
factory
@
task_factory
.
register_task_cls
(
exp_cfg
.
SemanticSegmentation3DTask
)
@
task_factory
.
register_task_cls
(
exp_cfg
.
SemanticSegmentation3DTask
)
...
...
official/
vision/beta/
projects/volumetric_models/tasks/semantic_segmentation_3d_test.py
→
official/projects/volumetric_models/tasks/semantic_segmentation_3d_test.py
View file @
e9355843
...
@@ -26,11 +26,11 @@ import tensorflow as tf
...
@@ -26,11 +26,11 @@ import tensorflow as tf
from
official.common
import
registry_imports
# pylint: disable=unused-import
from
official.common
import
registry_imports
# pylint: disable=unused-import
from
official.core
import
exp_factory
from
official.core
import
exp_factory
from
official.modeling
import
optimization
from
official.modeling
import
optimization
from
official.projects.volumetric_models.evaluation
import
segmentation_metrics
from
official.projects.volumetric_models.modeling
import
backbones
from
official.projects.volumetric_models.modeling
import
decoders
from
official.projects.volumetric_models.tasks
import
semantic_segmentation_3d
as
img_seg_task
from
official.vision.beta.dataloaders
import
tfexample_utils
from
official.vision.beta.dataloaders
import
tfexample_utils
from
official.vision.beta.projects.volumetric_models.evaluation
import
segmentation_metrics
from
official.vision.beta.projects.volumetric_models.modeling
import
backbones
from
official.vision.beta.projects.volumetric_models.modeling
import
decoders
from
official.vision.beta.projects.volumetric_models.tasks
import
semantic_segmentation_3d
as
img_seg_task
class
SemanticSegmentationTaskTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
class
SemanticSegmentationTaskTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
...
...
official/
vision/beta/
projects/volumetric_models/train.py
→
official/projects/volumetric_models/train.py
View file @
e9355843
...
@@ -15,12 +15,11 @@
...
@@ -15,12 +15,11 @@
"""TensorFlow Model Garden Vision training driver."""
"""TensorFlow Model Garden Vision training driver."""
from
absl
import
app
from
absl
import
app
import
gin
# pylint: disable=unused-import
import
gin
# pylint: disable=unused-import
from
official.common
import
flags
as
tfm_flags
from
official.common
import
flags
as
tfm_flags
from
official.projects.volumetric_models
import
registry_imports
# pylint: disable=unused-import
from
official.vision.beta
import
train
from
official.vision.beta
import
train
from
official.vision.beta.projects.volumetric_models
import
registry_imports
# pylint: disable=unused-import
def
main
(
_
):
def
main
(
_
):
...
...
official/
vision/beta/
projects/volumetric_models/train_test.py
→
official/projects/volumetric_models/train_test.py
View file @
e9355843
...
@@ -20,9 +20,8 @@ from absl import flags
...
@@ -20,9 +20,8 @@ from absl import flags
from
absl
import
logging
from
absl
import
logging
from
absl.testing
import
flagsaver
from
absl.testing
import
flagsaver
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.projects.volumetric_models
import
train
as
train_lib
from
official.vision.beta.dataloaders
import
tfexample_utils
from
official.vision.beta.dataloaders
import
tfexample_utils
from
official.vision.beta.projects.volumetric_models
import
train
as
train_lib
FLAGS
=
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
...
...
official/vision/beta/projects/yt8m/configs/yt8m.py
View file @
e9355843
...
@@ -13,14 +13,15 @@
...
@@ -13,14 +13,15 @@
# limitations under the License.
# limitations under the License.
"""Video classification configuration definition."""
"""Video classification configuration definition."""
import
dataclasses
from
typing
import
Optional
,
Tuple
from
typing
import
Optional
,
Tuple
from
absl
import
flags
from
absl
import
flags
import
dataclasses
from
official.core
import
config_definitions
as
cfg
from
official.core
import
config_definitions
as
cfg
from
official.core
import
exp_factory
from
official.core
import
exp_factory
from
official.modeling
import
hyperparams
from
official.modeling
import
hyperparams
from
official.modeling
import
optimization
from
official.modeling
import
optimization
from
official.vision.beta.configs
import
common
FLAGS
=
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
...
@@ -66,16 +67,28 @@ def yt8m(is_training):
...
@@ -66,16 +67,28 @@ def yt8m(is_training):
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
YT8M
Model
(
hyperparams
.
Config
):
class
Moe
Model
(
hyperparams
.
Config
):
"""The model config."""
"""The model config."""
cluster_size
:
int
=
2048
num_mixtures
:
int
=
5
hidden_size
:
int
=
2048
l2_penalty
:
float
=
1e-5
use_input_context_gate
:
bool
=
False
use_output_context_gate
:
bool
=
False
@
dataclasses
.
dataclass
class
DbofModel
(
hyperparams
.
Config
):
"""The model config."""
cluster_size
:
int
=
3000
hidden_size
:
int
=
2000
add_batch_norm
:
bool
=
True
add_batch_norm
:
bool
=
True
sample_random_frames
:
bool
=
True
sample_random_frames
:
bool
=
True
is_training
:
bool
=
Tru
e
use_context_gate_cluster_layer
:
bool
=
Fals
e
activation
:
str
=
'relu6'
context_gate_cluster_bottleneck_size
:
int
=
0
pooling_method
:
str
=
'average'
pooling_method
:
str
=
'average'
yt8m_agg_classifier_model
:
str
=
'MoeModel'
yt8m_agg_classifier_model
:
str
=
'MoeModel'
agg_model
:
hyperparams
.
Config
=
MoeModel
()
norm_activation
:
common
.
NormActivation
=
common
.
NormActivation
(
activation
=
'relu'
,
use_sync_bn
=
False
)
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
...
@@ -83,12 +96,13 @@ class Losses(hyperparams.Config):
...
@@ -83,12 +96,13 @@ class Losses(hyperparams.Config):
name
:
str
=
'binary_crossentropy'
name
:
str
=
'binary_crossentropy'
from_logits
:
bool
=
False
from_logits
:
bool
=
False
label_smoothing
:
float
=
0.0
label_smoothing
:
float
=
0.0
l2_weight_decay
:
float
=
1e-5
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
YT8MTask
(
cfg
.
TaskConfig
):
class
YT8MTask
(
cfg
.
TaskConfig
):
"""The task config."""
"""The task config."""
model
:
YT8M
Model
=
YT8M
Model
()
model
:
Dbof
Model
=
Dbof
Model
()
train_data
:
DataConfig
=
yt8m
(
is_training
=
True
)
train_data
:
DataConfig
=
yt8m
(
is_training
=
True
)
validation_data
:
DataConfig
=
yt8m
(
is_training
=
False
)
validation_data
:
DataConfig
=
yt8m
(
is_training
=
False
)
losses
:
Losses
=
Losses
()
losses
:
Losses
=
Losses
()
...
@@ -102,8 +116,8 @@ def add_trainer(
...
@@ -102,8 +116,8 @@ def add_trainer(
experiment
:
cfg
.
ExperimentConfig
,
experiment
:
cfg
.
ExperimentConfig
,
train_batch_size
:
int
,
train_batch_size
:
int
,
eval_batch_size
:
int
,
eval_batch_size
:
int
,
learning_rate
:
float
=
0.00
5
,
learning_rate
:
float
=
0.00
01
,
train_epochs
:
int
=
44
,
train_epochs
:
int
=
50
,
):
):
"""Add and config a trainer to the experiment config."""
"""Add and config a trainer to the experiment config."""
if
YT8M_TRAIN_EXAMPLES
<=
0
:
if
YT8M_TRAIN_EXAMPLES
<=
0
:
...
@@ -115,13 +129,14 @@ def add_trainer(
...
@@ -115,13 +129,14 @@ def add_trainer(
experiment
.
task
.
train_data
.
global_batch_size
=
train_batch_size
experiment
.
task
.
train_data
.
global_batch_size
=
train_batch_size
experiment
.
task
.
validation_data
.
global_batch_size
=
eval_batch_size
experiment
.
task
.
validation_data
.
global_batch_size
=
eval_batch_size
steps_per_epoch
=
YT8M_TRAIN_EXAMPLES
//
train_batch_size
steps_per_epoch
=
YT8M_TRAIN_EXAMPLES
//
train_batch_size
steps_per_loop
=
30
experiment
.
trainer
=
cfg
.
TrainerConfig
(
experiment
.
trainer
=
cfg
.
TrainerConfig
(
steps_per_loop
=
steps_per_
epoch
,
steps_per_loop
=
steps_per_
loop
,
summary_interval
=
steps_per_
epoch
,
summary_interval
=
steps_per_
loop
,
checkpoint_interval
=
steps_per_
epoch
,
checkpoint_interval
=
steps_per_
loop
,
train_steps
=
train_epochs
*
steps_per_epoch
,
train_steps
=
train_epochs
*
steps_per_epoch
,
validation_steps
=
YT8M_VAL_EXAMPLES
//
eval_batch_size
,
validation_steps
=
YT8M_VAL_EXAMPLES
//
eval_batch_size
,
validation_interval
=
steps_per_
epoch
,
validation_interval
=
steps_per_
loop
,
optimizer_config
=
optimization
.
OptimizationConfig
({
optimizer_config
=
optimization
.
OptimizationConfig
({
'optimizer'
:
{
'optimizer'
:
{
'type'
:
'adam'
,
'type'
:
'adam'
,
...
@@ -132,9 +147,18 @@ def add_trainer(
...
@@ -132,9 +147,18 @@ def add_trainer(
'exponential'
:
{
'exponential'
:
{
'initial_learning_rate'
:
learning_rate
,
'initial_learning_rate'
:
learning_rate
,
'decay_rate'
:
0.95
,
'decay_rate'
:
0.95
,
'decay_steps'
:
1500000
,
'decay_steps'
:
int
(
steps_per_epoch
*
1.5
),
'offset'
:
500
,
}
}
},
},
'warmup'
:
{
'linear'
:
{
'name'
:
'linear'
,
'warmup_learning_rate'
:
0
,
'warmup_steps'
:
500
,
},
'type'
:
'linear'
,
}
}))
}))
return
experiment
return
experiment
...
@@ -154,4 +178,17 @@ def yt8m_experiment() -> cfg.ExperimentConfig:
...
@@ -154,4 +178,17 @@ def yt8m_experiment() -> cfg.ExperimentConfig:
'task.train_data.feature_names != None'
,
'task.train_data.feature_names != None'
,
])
])
return
add_trainer
(
exp_config
,
train_batch_size
=
512
,
eval_batch_size
=
512
)
# Per TPUv3 Core batch size 16GB HBM. `factor` in range(1, 26)
factor
=
1
num_cores
=
32
# for TPU 4x4
train_per_core_bs
=
32
*
factor
train_bs
=
train_per_core_bs
*
num_cores
eval_per_core_bs
=
32
*
50
# multiplier<=100
eval_bs
=
eval_per_core_bs
*
num_cores
# based lr=0.0001 for bs=512
return
add_trainer
(
exp_config
,
train_batch_size
=
train_bs
,
eval_batch_size
=
eval_bs
,
learning_rate
=
0.0001
*
(
train_bs
/
512
),
train_epochs
=
100
)
official/vision/beta/projects/yt8m/modeling/yt8m_agg_models.py
View file @
e9355843
...
@@ -13,13 +13,12 @@
...
@@ -13,13 +13,12 @@
# limitations under the License.
# limitations under the License.
"""Contains model definitions."""
"""Contains model definitions."""
from
typing
import
Optional
,
Dict
,
Any
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.vision.beta.projects.yt8m.modeling
import
yt8m_model_utils
as
utils
layers
=
tf
.
keras
.
layers
layers
=
tf
.
keras
.
layers
regularizers
=
tf
.
keras
.
regularizers
# The number of mixtures (excluding the dummy 'expert') used for MoeModel.
moe_num_mixtures
=
2
class
LogisticModel
():
class
LogisticModel
():
...
@@ -41,7 +40,7 @@ class LogisticModel():
...
@@ -41,7 +40,7 @@ class LogisticModel():
output
=
layers
.
Dense
(
output
=
layers
.
Dense
(
vocab_size
,
vocab_size
,
activation
=
tf
.
nn
.
sigmoid
,
activation
=
tf
.
nn
.
sigmoid
,
kernel_regularizer
=
regularizers
.
l2
(
l2_penalty
))(
kernel_regularizer
=
tf
.
keras
.
regularizers
.
l2
(
l2_penalty
))(
model_input
)
model_input
)
return
{
"predictions"
:
output
}
return
{
"predictions"
:
output
}
...
@@ -52,8 +51,12 @@ class MoeModel():
...
@@ -52,8 +51,12 @@ class MoeModel():
def
create_model
(
self
,
def
create_model
(
self
,
model_input
,
model_input
,
vocab_size
,
vocab_size
,
num_mixtures
=
None
,
num_mixtures
:
int
=
2
,
l2_penalty
=
1e-8
):
use_input_context_gate
:
bool
=
False
,
use_output_context_gate
:
bool
=
False
,
normalizer_fn
=
None
,
normalizer_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
l2_penalty
:
float
=
1e-5
):
"""Creates a Mixture of (Logistic) Experts model.
"""Creates a Mixture of (Logistic) Experts model.
The model consists of a per-class softmax distribution over a
The model consists of a per-class softmax distribution over a
...
@@ -64,6 +67,10 @@ class MoeModel():
...
@@ -64,6 +67,10 @@ class MoeModel():
vocab_size: The number of classes in the dataset.
vocab_size: The number of classes in the dataset.
num_mixtures: The number of mixtures (excluding a dummy 'expert' that
num_mixtures: The number of mixtures (excluding a dummy 'expert' that
always predicts the non-existence of an entity).
always predicts the non-existence of an entity).
use_input_context_gate: if True apply context gate layer to the input.
use_output_context_gate: if True apply context gate layer to the output.
normalizer_fn: normalization op constructor (e.g. batch norm).
normalizer_params: parameters to the `normalizer_fn`.
l2_penalty: How much to penalize the squared magnitudes of parameter
l2_penalty: How much to penalize the squared magnitudes of parameter
values.
values.
...
@@ -72,18 +79,23 @@ class MoeModel():
...
@@ -72,18 +79,23 @@ class MoeModel():
of the model in the 'predictions' key. The dimensions of the tensor
of the model in the 'predictions' key. The dimensions of the tensor
are batch_size x num_classes.
are batch_size x num_classes.
"""
"""
num_mixtures
=
num_mixtures
or
moe_num_mixtures
if
use_input_context_gate
:
model_input
=
utils
.
context_gate
(
model_input
,
normalizer_fn
=
normalizer_fn
,
normalizer_params
=
normalizer_params
,
)
gate_activations
=
layers
.
Dense
(
gate_activations
=
layers
.
Dense
(
vocab_size
*
(
num_mixtures
+
1
),
vocab_size
*
(
num_mixtures
+
1
),
activation
=
None
,
activation
=
None
,
bias_initializer
=
None
,
bias_initializer
=
None
,
kernel_regularizer
=
regularizers
.
l2
(
l2_penalty
))(
kernel_regularizer
=
tf
.
keras
.
regularizers
.
l2
(
l2_penalty
))(
model_input
)
model_input
)
expert_activations
=
layers
.
Dense
(
expert_activations
=
layers
.
Dense
(
vocab_size
*
num_mixtures
,
vocab_size
*
num_mixtures
,
activation
=
None
,
activation
=
None
,
kernel_regularizer
=
regularizers
.
l2
(
l2_penalty
))(
kernel_regularizer
=
tf
.
keras
.
regularizers
.
l2
(
l2_penalty
))(
model_input
)
model_input
)
gating_distribution
=
tf
.
nn
.
softmax
(
gating_distribution
=
tf
.
nn
.
softmax
(
...
@@ -98,4 +110,10 @@ class MoeModel():
...
@@ -98,4 +110,10 @@ class MoeModel():
gating_distribution
[:,
:
num_mixtures
]
*
expert_distribution
,
1
)
gating_distribution
[:,
:
num_mixtures
]
*
expert_distribution
,
1
)
final_probabilities
=
tf
.
reshape
(
final_probabilities_by_class_and_batch
,
final_probabilities
=
tf
.
reshape
(
final_probabilities_by_class_and_batch
,
[
-
1
,
vocab_size
])
[
-
1
,
vocab_size
])
if
use_output_context_gate
:
final_probabilities
=
utils
.
context_gate
(
final_probabilities
,
normalizer_fn
=
normalizer_fn
,
normalizer_params
=
normalizer_params
,
)
return
{
"predictions"
:
final_probabilities
}
return
{
"predictions"
:
final_probabilities
}
official/vision/beta/projects/yt8m/modeling/yt8m_model.py
View file @
e9355843
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
# limitations under the License.
# limitations under the License.
"""YT8M model definition."""
"""YT8M model definition."""
from
typing
import
Optional
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.modeling
import
tf_utils
from
official.modeling
import
tf_utils
...
@@ -23,23 +24,43 @@ from official.vision.beta.projects.yt8m.modeling import yt8m_model_utils as util
...
@@ -23,23 +24,43 @@ from official.vision.beta.projects.yt8m.modeling import yt8m_model_utils as util
layers
=
tf
.
keras
.
layers
layers
=
tf
.
keras
.
layers
class
YT8M
Model
(
tf
.
keras
.
Model
):
class
Dbof
Model
(
tf
.
keras
.
Model
):
"""A YT8M model class builder.
"""
"""A YT8M model class builder.
def
__init__
(
self
,
Creates a Deep Bag of Frames model.
input_params
:
yt8m_cfg
.
YT8MModel
,
The model projects the features for each frame into a higher dimensional
'clustering' space, pools across frames in that space, and then
uses a configurable video-level model to classify the now aggregated features.
The model will randomly sample either frames or sequences of frames during
training to speed up convergence.
"""
def
__init__
(
self
,
params
:
yt8m_cfg
.
DbofModel
,
num_frames
=
30
,
num_frames
=
30
,
num_classes
=
3862
,
num_classes
=
3862
,
input_specs
=
layers
.
InputSpec
(
shape
=
[
None
,
None
,
1152
]),
input_specs
=
layers
.
InputSpec
(
shape
=
[
None
,
None
,
1152
]),
kernel_regularizer
:
Optional
[
tf
.
keras
.
regularizers
.
Regularizer
]
=
None
,
activation
:
str
=
"relu"
,
use_sync_bn
:
bool
=
False
,
norm_momentum
:
float
=
0.99
,
norm_epsilon
:
float
=
0.001
,
**
kwargs
):
**
kwargs
):
"""YT8M initialization function.
"""YT8M initialization function.
Args:
Args:
input_
params: model configuration parameters
params: model configuration parameters
num_frames: `int` number of frames in a single input.
num_frames: `int` number of frames in a single input.
num_classes: `int` number of classes in dataset.
num_classes: `int` number of classes in dataset.
input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
[batch_size x num_frames x num_features]
[batch_size x num_frames x num_features]
kernel_regularizer: tf.keras.regularizers.Regularizer object. Default to
None.
activation: A `str` of name of the activation function.
use_sync_bn: If True, use synchronized batch normalization.
norm_momentum: A `float` of normalization momentum for the moving average.
norm_epsilon: A `float` added to variance to avoid dividing by zero.
**kwargs: keyword arguments to be passed.
**kwargs: keyword arguments to be passed.
"""
"""
...
@@ -48,12 +69,19 @@ class YT8MModel(tf.keras.Model):
...
@@ -48,12 +69,19 @@ class YT8MModel(tf.keras.Model):
"input_specs"
:
input_specs
,
"input_specs"
:
input_specs
,
"num_classes"
:
num_classes
,
"num_classes"
:
num_classes
,
"num_frames"
:
num_frames
,
"num_frames"
:
num_frames
,
"
input_
params"
:
input_
params
"params"
:
params
}
}
self
.
_num_classes
=
num_classes
self
.
_num_classes
=
num_classes
self
.
_input_specs
=
input_specs
self
.
_input_specs
=
input_specs
self
.
_act_fn
=
tf_utils
.
get_activation
(
input_params
.
activation
)
self
.
_act_fn
=
tf_utils
.
get_activation
(
activation
)
self
.
_is_training
=
input_params
.
is_training
if
use_sync_bn
:
self
.
_norm
=
layers
.
experimental
.
SyncBatchNormalization
else
:
self
.
_norm
=
layers
.
BatchNormalization
if
tf
.
keras
.
backend
.
image_data_format
()
==
"channels_last"
:
bn_axis
=
-
1
else
:
bn_axis
=
1
# [batch_size x num_frames x num_features]
# [batch_size x num_frames x num_features]
feature_size
=
input_specs
.
shape
[
-
1
]
feature_size
=
input_specs
.
shape
[
-
1
]
...
@@ -63,31 +91,34 @@ class YT8MModel(tf.keras.Model):
...
@@ -63,31 +91,34 @@ class YT8MModel(tf.keras.Model):
tf
.
summary
.
histogram
(
"input_hist"
,
model_input
)
tf
.
summary
.
histogram
(
"input_hist"
,
model_input
)
# configure model
# configure model
if
input_params
.
add_batch_norm
:
if
params
.
add_batch_norm
:
reshaped_input
=
layers
.
BatchNormalization
(
reshaped_input
=
self
.
_norm
(
name
=
"input_bn"
,
scale
=
True
,
center
=
True
,
axis
=
bn_axis
,
trainable
=
self
.
_is_training
)(
momentum
=
norm_momentum
,
epsilon
=
norm_epsilon
,
name
=
"input_bn"
)(
reshaped_input
)
reshaped_input
)
# activation = reshaped input * cluster weights
# activation = reshaped input * cluster weights
if
params
.
cluster_size
>
0
:
activation
=
layers
.
Dense
(
activation
=
layers
.
Dense
(
input_params
.
cluster_size
,
params
.
cluster_size
,
kernel_regularizer
=
kernel_regularizer
,
kernel_initializer
=
tf
.
random_normal_initializer
(
kernel_initializer
=
tf
.
random_normal_initializer
(
stddev
=
1
/
tf
.
sqrt
(
tf
.
cast
(
feature_size
,
tf
.
float32
))))(
stddev
=
1
/
tf
.
sqrt
(
tf
.
cast
(
feature_size
,
tf
.
float32
))))(
reshaped_input
)
reshaped_input
)
if
input_
params
.
add_batch_norm
:
if
params
.
add_batch_norm
:
activation
=
layers
.
BatchNormalization
(
activation
=
self
.
_norm
(
name
=
"cluster_bn"
,
axis
=
bn_axis
,
scale
=
True
,
momentum
=
norm_momentum
,
center
=
True
,
epsilon
=
norm_epsilon
,
trainable
=
self
.
_is_training
)(
name
=
"cluster_bn"
)(
activation
)
activation
)
else
:
else
:
cluster_biases
=
tf
.
Variable
(
cluster_biases
=
tf
.
Variable
(
tf
.
random_normal_initializer
(
stddev
=
1
/
tf
.
math
.
sqrt
(
feature_size
))(
tf
.
random_normal_initializer
(
stddev
=
1
/
tf
.
math
.
sqrt
(
feature_size
))(
shape
=
[
input_
params
.
cluster_size
]),
shape
=
[
params
.
cluster_size
]),
name
=
"cluster_biases"
)
name
=
"cluster_biases"
)
tf
.
summary
.
histogram
(
"cluster_biases"
,
cluster_biases
)
tf
.
summary
.
histogram
(
"cluster_biases"
,
cluster_biases
)
activation
+=
cluster_biases
activation
+=
cluster_biases
...
@@ -95,30 +126,42 @@ class YT8MModel(tf.keras.Model):
...
@@ -95,30 +126,42 @@ class YT8MModel(tf.keras.Model):
activation
=
self
.
_act_fn
(
activation
)
activation
=
self
.
_act_fn
(
activation
)
tf
.
summary
.
histogram
(
"cluster_output"
,
activation
)
tf
.
summary
.
histogram
(
"cluster_output"
,
activation
)
activation
=
tf
.
reshape
(
activation
,
if
params
.
use_context_gate_cluster_layer
:
[
-
1
,
num_frames
,
input_params
.
cluster_size
])
pooling_method
=
None
activation
=
utils
.
FramePooling
(
activation
,
input_params
.
pooling_method
)
norm_args
=
dict
(
axis
=
bn_axis
,
momentum
=
norm_momentum
,
epsilon
=
norm_epsilon
,
name
=
"context_gate_bn"
)
activation
=
utils
.
context_gate
(
activation
,
normalizer_fn
=
self
.
_norm
,
normalizer_params
=
norm_args
,
pooling_method
=
pooling_method
,
hidden_layer_size
=
params
.
context_gate_cluster_bottleneck_size
,
kernel_regularizer
=
kernel_regularizer
)
activation
=
tf
.
reshape
(
activation
,
[
-
1
,
num_frames
,
params
.
cluster_size
])
activation
=
utils
.
frame_pooling
(
activation
,
params
.
pooling_method
)
# activation = activation * hidden1_weights
# activation = activation * hidden1_weights
activation
=
layers
.
Dense
(
activation
=
layers
.
Dense
(
input_params
.
hidden_size
,
params
.
hidden_size
,
kernel_regularizer
=
kernel_regularizer
,
kernel_initializer
=
tf
.
random_normal_initializer
(
kernel_initializer
=
tf
.
random_normal_initializer
(
stddev
=
1
/
stddev
=
1
/
tf
.
sqrt
(
tf
.
cast
(
params
.
cluster_size
,
tf
.
float32
))))(
tf
.
sqrt
(
tf
.
cast
(
input_params
.
cluster_size
,
tf
.
float32
))))(
activation
)
activation
)
if
input_
params
.
add_batch_norm
:
if
params
.
add_batch_norm
:
activation
=
layers
.
BatchNormalization
(
activation
=
self
.
_norm
(
name
=
"hidden1_bn"
,
axis
=
bn_axis
,
scale
=
True
,
momentum
=
norm_momentum
,
center
=
True
,
epsilon
=
norm_epsilon
,
trainable
=
self
.
_is_training
)(
name
=
"hidden1_bn"
)(
activation
)
activation
)
else
:
else
:
hidden1_biases
=
tf
.
Variable
(
hidden1_biases
=
tf
.
Variable
(
tf
.
random_normal_initializer
(
stddev
=
0.01
)(
tf
.
random_normal_initializer
(
stddev
=
0.01
)(
shape
=
[
params
.
hidden_size
]),
shape
=
[
input_params
.
hidden_size
]),
name
=
"hidden1_biases"
)
name
=
"hidden1_biases"
)
tf
.
summary
.
histogram
(
"hidden1_biases"
,
hidden1_biases
)
tf
.
summary
.
histogram
(
"hidden1_biases"
,
hidden1_biases
)
...
@@ -128,9 +171,15 @@ class YT8MModel(tf.keras.Model):
...
@@ -128,9 +171,15 @@ class YT8MModel(tf.keras.Model):
tf
.
summary
.
histogram
(
"hidden1_output"
,
activation
)
tf
.
summary
.
histogram
(
"hidden1_output"
,
activation
)
aggregated_model
=
getattr
(
yt8m_agg_models
,
aggregated_model
=
getattr
(
yt8m_agg_models
,
input_params
.
yt8m_agg_classifier_model
)
params
.
yt8m_agg_classifier_model
)
norm_args
=
dict
(
axis
=
bn_axis
,
momentum
=
norm_momentum
,
epsilon
=
norm_epsilon
)
output
=
aggregated_model
().
create_model
(
output
=
aggregated_model
().
create_model
(
model_input
=
activation
,
vocab_size
=
self
.
_num_classes
)
model_input
=
activation
,
vocab_size
=
self
.
_num_classes
,
num_mixtures
=
params
.
agg_model
.
num_mixtures
,
normalizer_fn
=
self
.
_norm
,
normalizer_params
=
norm_args
,
l2_penalty
=
params
.
agg_model
.
l2_penalty
)
super
().
__init__
(
super
().
__init__
(
inputs
=
model_input
,
outputs
=
output
.
get
(
"predictions"
),
**
kwargs
)
inputs
=
model_input
,
outputs
=
output
.
get
(
"predictions"
),
**
kwargs
)
...
...
official/vision/beta/projects/yt8m/modeling/yt8m_model_test.py
View file @
e9355843
...
@@ -37,8 +37,8 @@ class YT8MNetworkTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -37,8 +37,8 @@ class YT8MNetworkTest(parameterized.TestCase, tf.test.TestCase):
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
num_frames
,
feature_dims
])
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
num_frames
,
feature_dims
])
num_classes
=
3862
num_classes
=
3862
model
=
yt8m_model
.
YT8M
Model
(
model
=
yt8m_model
.
Dbof
Model
(
input_
params
=
yt8m_cfg
.
YT8MTask
.
model
,
params
=
yt8m_cfg
.
YT8MTask
.
model
,
num_frames
=
num_frames
,
num_frames
=
num_frames
,
num_classes
=
num_classes
,
num_classes
=
num_classes
,
input_specs
=
input_specs
)
input_specs
=
input_specs
)
...
@@ -49,10 +49,10 @@ class YT8MNetworkTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -49,10 +49,10 @@ class YT8MNetworkTest(parameterized.TestCase, tf.test.TestCase):
self
.
assertAllEqual
([
2
,
num_classes
],
logits
.
numpy
().
shape
)
self
.
assertAllEqual
([
2
,
num_classes
],
logits
.
numpy
().
shape
)
def
test_serialize_deserialize
(
self
):
def
test_serialize_deserialize
(
self
):
model
=
yt8m_model
.
YT8M
Model
(
input_
params
=
yt8m_cfg
.
YT8MTask
.
model
)
model
=
yt8m_model
.
Dbof
Model
(
params
=
yt8m_cfg
.
YT8MTask
.
model
)
config
=
model
.
get_config
()
config
=
model
.
get_config
()
new_model
=
yt8m_model
.
YT8M
Model
.
from_config
(
config
)
new_model
=
yt8m_model
.
Dbof
Model
.
from_config
(
config
)
# If the serialization was successful,
# If the serialization was successful,
# the new config should match the old.
# the new config should match the old.
...
...
official/vision/beta/projects/yt8m/modeling/yt8m_model_utils.py
View file @
e9355843
...
@@ -13,10 +13,12 @@
...
@@ -13,10 +13,12 @@
# limitations under the License.
# limitations under the License.
"""Contains a collection of util functions for model construction."""
"""Contains a collection of util functions for model construction."""
from
typing
import
Dict
,
Optional
,
Union
,
Any
import
tensorflow
as
tf
import
tensorflow
as
tf
def
S
ample
R
andom
S
equence
(
model_input
,
num_frames
,
num_samples
):
def
s
ample
_r
andom
_s
equence
(
model_input
,
num_frames
,
num_samples
):
"""Samples a random sequence of frames of size num_samples.
"""Samples a random sequence of frames of size num_samples.
Args:
Args:
...
@@ -44,7 +46,7 @@ def SampleRandomSequence(model_input, num_frames, num_samples):
...
@@ -44,7 +46,7 @@ def SampleRandomSequence(model_input, num_frames, num_samples):
return
tf
.
gather_nd
(
model_input
,
index
)
return
tf
.
gather_nd
(
model_input
,
index
)
def
S
ample
R
andom
F
rames
(
model_input
,
num_frames
,
num_samples
):
def
s
ample
_r
andom
_f
rames
(
model_input
,
num_frames
,
num_samples
):
"""Samples a random set of frames of size num_samples.
"""Samples a random set of frames of size num_samples.
Args:
Args:
...
@@ -66,7 +68,7 @@ def SampleRandomFrames(model_input, num_frames, num_samples):
...
@@ -66,7 +68,7 @@ def SampleRandomFrames(model_input, num_frames, num_samples):
return
tf
.
gather_nd
(
model_input
,
index
)
return
tf
.
gather_nd
(
model_input
,
index
)
def
F
rame
P
ooling
(
frames
,
method
):
def
f
rame
_p
ooling
(
frames
,
method
):
"""Pools over the frames of a video.
"""Pools over the frames of a video.
Args:
Args:
...
@@ -93,3 +95,110 @@ def FramePooling(frames, method):
...
@@ -93,3 +95,110 @@ def FramePooling(frames, method):
raise
ValueError
(
"Unrecognized pooling method: %s"
%
method
)
raise
ValueError
(
"Unrecognized pooling method: %s"
%
method
)
return
reduced
return
reduced
def context_gate(
    input_features,
    normalizer_fn=None,
    normalizer_params: Optional[Dict[str, Any]] = None,
    kernel_initializer: Union[
        str, tf.keras.initializers.Initializer] = "glorot_uniform",
    kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
    bias_initializer: Union[str, tf.keras.initializers.Initializer] = "zeros",
    hidden_layer_size: int = 0,
    pooling_method: Optional[str] = None,
    additive_residual: bool = False):
  """Context Gating.

  More details: https://arxiv.org/pdf/1706.06905.pdf.

  Args:
    input_features: a tensor of at least rank 2 whose last dimension is
      statically known.
    normalizer_fn: Optional callable that returns a normalization layer when
      called with `**normalizer_params` (e.g.
      tf.keras.layers.BatchNormalization). If set, the resulting layer is
      applied to the output of the bottleneck layer (when present) and to the
      output of the gating layer. If None, no normalization is applied.
    normalizer_params: Keyword arguments passed to `normalizer_fn`. Defaults
      to an empty dict.
    kernel_initializer: Weight initializer for the dense layers (name or
      initializer instance).
    kernel_regularizer: Weight regularizer for the dense layers, or None.
    bias_initializer: Bias initializer for the dense layers (name or
      initializer instance).
    hidden_layer_size: Dimensionality of the context gating hidden layer, if
      any. If smaller than 2 (the default is 0), a single fully-connected
      gating layer of shape [input_size x input_size] is applied. If set to an
      int N >= 2, the gating layer is factorized into
      [input_size x N] x [N x input_size] as in the squeeze-and-excitation
      block from https://arxiv.org/pdf/1709.01507.pdf.
    pooling_method: Whether to perform global pooling of the local features
      before applying the context gating layer. This is relevant only if the
      input_features tensor has rank > 2, e.g., it's a sequence of frame
      features, [batch_size, num_frames, feature_dim], or spatial convolution
      features, [batch_size*num_frames, h, w, feature_dim]. If the inputs are
      a set of local features and pooling_method is not None, will pool
      features across all but the batch_size dimension using the specified
      pooling method, and pass the aggregated features as context to the
      gating layer. For a list of pooling methods, see the frame_pooling()
      function.
    additive_residual: If true, will use ReLu6-activated (additive) residual
      connections instead of Sigmoid-activated (multiplicative) connections
      when combining the input_features with the context gating branch.

  Returns:
    A tensor with the same shape as input_features.

  Raises:
    ValueError: If `pooling_method` is set but `input_features` has rank 2 or
      lower.
  """
  if normalizer_params is None:
    normalizer_params = {}

  with tf.name_scope("ContextGating"):
    static_shape = input_features.shape.as_list()
    num_dimensions = len(static_shape)
    feature_size = static_shape[-1]

    if pooling_method:
      # Explicit check instead of `assert`, which is stripped under `-O`.
      if num_dimensions <= 2:
        raise ValueError(
            "pooling_method requires input_features of rank > 2, got rank "
            "%d." % num_dimensions)
      # Collapse the inner axes of the original features shape into a 3D tensor
      original_shape = tf.shape(input_features)
      # The last dimension will change after concatenating the context
      new_shape = tf.concat(
          [original_shape[:-1], tf.constant([2 * feature_size])], 0)
      batch_size = original_shape[0]
      reshaped_features = tf.reshape(input_features,
                                     [batch_size, -1, feature_size])
      num_features = tf.shape(reshaped_features)[1]
      # Pool the feature channels across the inner axes to get global context
      context_features = frame_pooling(reshaped_features, pooling_method)
      context_features = tf.expand_dims(context_features, 1)
      # Replicate the global context features and concat to the local features.
      context_features = tf.tile(context_features, [1, num_features, 1])
      context_features = tf.concat([reshaped_features, context_features], 2)
      context_features = tf.reshape(context_features, shape=new_shape)
    else:
      context_features = input_features

    # NOTE(review): the same `normalizer_params` (including any `name` entry)
    # are used for both normalization layers below -- confirm that callers
    # passing a fixed name together with hidden_layer_size >= 2 do not run
    # into duplicate layer names.
    if hidden_layer_size >= 2:
      # Factorized gating: squeeze the context through a bottleneck first.
      gates_bottleneck = tf.keras.layers.Dense(
          hidden_layer_size,
          activation="relu6",
          kernel_initializer=kernel_initializer,
          bias_initializer=bias_initializer,
          kernel_regularizer=kernel_regularizer,
      )(context_features)
      if normalizer_fn:
        gates_bottleneck = normalizer_fn(**normalizer_params)(gates_bottleneck)
    else:
      gates_bottleneck = context_features

    activation_fn = tf.nn.relu6 if additive_residual else tf.nn.sigmoid
    gates = tf.keras.layers.Dense(
        feature_size,
        activation=activation_fn,
        kernel_initializer=kernel_initializer,
        bias_initializer=bias_initializer,
        kernel_regularizer=kernel_regularizer,
    )(gates_bottleneck)
    # The normalization, if any, is applied after the gate activation.
    if normalizer_fn:
      gates = normalizer_fn(**normalizer_params)(gates)

    # Combine the gates with the original (un-pooled) input features.
    if additive_residual:
      input_features += gates
    else:
      input_features *= gates

    return input_features
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment