Project: ModelZoo / ResNet50_tensorflow

Commit 90585434, authored Oct 17, 2022 by Chaochao Yan, committed by A. Unique TensorFlower on Oct 17, 2022.

    Internal change

    PiperOrigin-RevId: 481733792

Parent: d309fff8
Showing 11 changed files with 849 additions and 26 deletions (+849 -26):
official/projects/vit/configs/image_classification.py    +3    -3
official/projects/vit/modeling/vit.py                    +1    -1
official/vision/configs/backbones.py                     +29   -4
official/vision/configs/image_classification.py          +198  -0
official/vision/configs/image_classification_test.py     +4    -1
official/vision/modeling/backbones/__init__.py           +1    -0
official/vision/modeling/backbones/vit.py                +322  -0
official/vision/modeling/backbones/vit_specs.py          +68   -0
official/vision/modeling/backbones/vit_test.py           +73   -0
official/vision/modeling/classification_model_test.py    +33   -4
official/vision/modeling/layers/nn_blocks.py             +117  -13
official/projects/vit/configs/image_classification.py

@@ -75,7 +75,7 @@ task_factory.register_task_cls(ImageClassificationTask)(
     image_classification.ImageClassificationTask)
 
 
-@exp_factory.register_config_factory('deit_imagenet_pretrain')
+@exp_factory.register_config_factory('legacy_deit_imagenet_pretrain')
 def image_classification_imagenet_deit_pretrain() -> cfg.ExperimentConfig:
   """Image classification on imagenet with vision transformer."""
   train_batch_size = 4096  # originally was 1024 but 4096 better for tpu v3-32
@@ -156,7 +156,7 @@ def image_classification_imagenet_deit_pretrain() -> cfg.ExperimentConfig:
   return config
 
 
-@exp_factory.register_config_factory('vit_imagenet_pretrain')
+@exp_factory.register_config_factory('legacy_vit_imagenet_pretrain')
 def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
   """Image classification on imagenet with vision transformer."""
   train_batch_size = 4096
@@ -220,7 +220,7 @@ def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
   return config
 
 
-@exp_factory.register_config_factory('vit_imagenet_finetune')
+@exp_factory.register_config_factory('legacy_vit_imagenet_finetune')
 def image_classification_imagenet_vit_finetune() -> cfg.ExperimentConfig:
   """Image classification on imagenet with vision transformer."""
   train_batch_size = 512
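Note: with the old registrations renamed under a legacy_ prefix, both generations of configs stay addressable through the experiment factory. A minimal sketch, assuming the relevant config modules have been imported so their registration decorators have run:

from official.core import exp_factory

# Old official/projects/vit config, now under the legacy_ prefix.
legacy_cfg = exp_factory.get_exp_config('legacy_vit_imagenet_pretrain')
# New official/vision config registered later in this commit.
new_cfg = exp_factory.get_exp_config('vit_imagenet_pretrain')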
official/projects/vit/modeling/vit.py

@@ -294,7 +294,7 @@ class VisionTransformer(tf.keras.Model):
     super(VisionTransformer, self).__init__(inputs=inputs, outputs=endpoints)
 
 
-@factory.register_backbone_builder('vit')
+@factory.register_backbone_builder('legacy_vit')
 def build_vit(input_specs,
               backbone_config,
               norm_activation_config,
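Note: the backbone registry gets the same treatment, freeing the 'vit' key for the new backbone added below. A hedged sketch of how the registry is consulted (factory.build_backbone dispatches on the config's type field; both modules must be imported for their registrations to take effect):

from official.vision.modeling.backbones import factory

# type='vit'        -> new builder in official/vision/modeling/backbones/vit.py
# type='legacy_vit' -> this file's builder in official/projects/vit/modeling/vit.py
backbone = factory.build_backbone(input_specs, backbone_config,
                                  norm_activation_config)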
official/vision/configs/backbones.py

@@ -14,13 +14,37 @@
 """Backbones configurations."""
 import dataclasses
-from typing import Optional, List
+from typing import List, Optional, Tuple
 
 # Import libraries
 from official.modeling import hyperparams
 
 
+@dataclasses.dataclass
+class Transformer(hyperparams.Config):
+  """Transformer config."""
+  mlp_dim: int = 1
+  num_heads: int = 1
+  num_layers: int = 1
+  attention_dropout_rate: float = 0.0
+  dropout_rate: float = 0.1
+
+
+@dataclasses.dataclass
+class VisionTransformer(hyperparams.Config):
+  """VisionTransformer config."""
+  model_name: str = 'vit-b16'
+  # pylint: disable=line-too-long
+  pooler: str = 'token'  # 'token', 'gap' or 'none'. If set to 'token', an extra classification token is added to sequence.
+  # pylint: enable=line-too-long
+  representation_size: int = 0
+  hidden_size: int = 1
+  patch_size: int = 16
+  transformer: Transformer = Transformer()
+  init_stochastic_depth_rate: float = 0.0
+  original_init: bool = True
+  pos_embed_shape: Optional[Tuple[int, int]] = None
+
+
 @dataclasses.dataclass
 class ResNet(hyperparams.Config):
   """ResNet config."""
@@ -120,6 +144,7 @@ class Backbone(hyperparams.OneOfConfig):
     spinenet_mobile: mobile spinenet backbone config.
     mobilenet: mobilenet backbone config.
     mobiledet: mobiledet backbone config.
+    vit: vision transformer backbone config.
   """
   type: Optional[str] = None
   resnet: ResNet = ResNet()
@@ -130,4 +155,4 @@ class Backbone(hyperparams.OneOfConfig):
   spinenet_mobile: SpineNetMobile = SpineNetMobile()
   mobilenet: MobileNet = MobileNet()
   mobiledet: MobileDet = MobileDet()
+  vit: VisionTransformer = VisionTransformer()
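Note: the new config classes behave like any other hyperparams.Config. A small usage sketch (field names from the dataclasses above; override() is the standard Config API):

from official.vision.configs import backbones

backbone = backbones.Backbone(
    type='vit',
    vit=backbones.VisionTransformer(model_name='vit-l16', pooler='gap'))
backbone.vit.transformer.override(dict(dropout_rate=0.0))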
official/vision/configs/image_classification.py

@@ -402,3 +402,201 @@ def image_classification_imagenet_mobilenet() -> cfg.ExperimentConfig:
       ])
   return config
+
+
+@exp_factory.register_config_factory('deit_imagenet_pretrain')
+def image_classification_imagenet_deit_pretrain() -> cfg.ExperimentConfig:
+  """Image classification on imagenet with vision transformer."""
+  train_batch_size = 4096  # originally was 1024 but 4096 better for tpu v3-32
+  eval_batch_size = 4096  # originally was 1024 but 4096 better for tpu v3-32
+  label_smoothing = 0.1
+  steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
+  config = cfg.ExperimentConfig(
+      task=ImageClassificationTask(
+          model=ImageClassificationModel(
+              num_classes=1001,
+              input_size=[224, 224, 3],
+              kernel_initializer='zeros',
+              backbone=backbones.Backbone(
+                  type='vit',
+                  vit=backbones.VisionTransformer(
+                      model_name='vit-b16',
+                      representation_size=768,
+                      init_stochastic_depth_rate=0.1,
+                      original_init=False,
+                      transformer=backbones.Transformer(
+                          dropout_rate=0.0, attention_dropout_rate=0.0)))),
+          losses=Losses(
+              l2_weight_decay=0.0,
+              label_smoothing=label_smoothing,
+              one_hot=False,
+              soft_labels=True),
+          train_data=DataConfig(
+              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
+              is_training=True,
+              global_batch_size=train_batch_size,
+              aug_type=common.Augmentation(
+                  type='randaug',
+                  randaug=common.RandAugment(
+                      magnitude=9, exclude_ops=['Cutout'])),
+              mixup_and_cutmix=common.MixupAndCutmix(
+                  label_smoothing=label_smoothing)),
+          validation_data=DataConfig(
+              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
+              is_training=False,
+              global_batch_size=eval_batch_size)),
+      trainer=cfg.TrainerConfig(
+          steps_per_loop=steps_per_epoch,
+          summary_interval=steps_per_epoch,
+          checkpoint_interval=steps_per_epoch,
+          train_steps=300 * steps_per_epoch,
+          validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
+          validation_interval=steps_per_epoch,
+          optimizer_config=optimization.OptimizationConfig({
+              'optimizer': {
+                  'type': 'adamw',
+                  'adamw': {
+                      'weight_decay_rate': 0.05,
+                      'include_in_weight_decay': r'.*(kernel|weight):0$',
+                      'gradient_clip_norm': 0.0
+                  }
+              },
+              'learning_rate': {
+                  'type': 'cosine',
+                  'cosine': {
+                      'initial_learning_rate': 0.0005 * train_batch_size / 512,
+                      'decay_steps': 300 * steps_per_epoch,
+                  }
+              },
+              'warmup': {
+                  'type': 'linear',
+                  'linear': {
+                      'warmup_steps': 5 * steps_per_epoch,
+                      'warmup_learning_rate': 0
+                  }
+              }
+          })),
+      restrictions=[
+          'task.train_data.is_training != None',
+          'task.validation_data.is_training != None'
+      ])
+  return config
+
+
+@exp_factory.register_config_factory('vit_imagenet_pretrain')
+def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
+  """Image classification on imagenet with vision transformer."""
+  train_batch_size = 4096
+  eval_batch_size = 4096
+  steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
+  config = cfg.ExperimentConfig(
+      task=ImageClassificationTask(
+          model=ImageClassificationModel(
+              num_classes=1001,
+              input_size=[224, 224, 3],
+              kernel_initializer='zeros',
+              backbone=backbones.Backbone(
+                  type='vit',
+                  vit=backbones.VisionTransformer(
+                      model_name='vit-b16', representation_size=768))),
+          losses=Losses(l2_weight_decay=0.0),
+          train_data=DataConfig(
+              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
+              is_training=True,
+              global_batch_size=train_batch_size),
+          validation_data=DataConfig(
+              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
+              is_training=False,
+              global_batch_size=eval_batch_size)),
+      trainer=cfg.TrainerConfig(
+          steps_per_loop=steps_per_epoch,
+          summary_interval=steps_per_epoch,
+          checkpoint_interval=steps_per_epoch,
+          train_steps=300 * steps_per_epoch,
+          validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
+          validation_interval=steps_per_epoch,
+          optimizer_config=optimization.OptimizationConfig({
+              'optimizer': {
+                  'type': 'adamw',
+                  'adamw': {
+                      'weight_decay_rate': 0.3,
+                      'include_in_weight_decay': r'.*(kernel|weight):0$',
+                      'gradient_clip_norm': 0.0
+                  }
+              },
+              'learning_rate': {
+                  'type': 'cosine',
+                  'cosine': {
+                      'initial_learning_rate': 0.003 * train_batch_size / 4096,
+                      'decay_steps': 300 * steps_per_epoch,
+                  }
+              },
+              'warmup': {
+                  'type': 'linear',
+                  'linear': {
+                      'warmup_steps': 10000,
+                      'warmup_learning_rate': 0
+                  }
+              }
+          })),
+      restrictions=[
+          'task.train_data.is_training != None',
+          'task.validation_data.is_training != None'
+      ])
+  return config
+
+
+@exp_factory.register_config_factory('vit_imagenet_finetune')
+def image_classification_imagenet_vit_finetune() -> cfg.ExperimentConfig:
+  """Image classification on imagenet with vision transformer."""
+  train_batch_size = 512
+  eval_batch_size = 512
+  steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
+  config = cfg.ExperimentConfig(
+      task=ImageClassificationTask(
+          model=ImageClassificationModel(
+              num_classes=1001,
+              input_size=[384, 384, 3],
+              backbone=backbones.Backbone(
+                  type='vit',
+                  vit=backbones.VisionTransformer(model_name='vit-b16'))),
+          losses=Losses(l2_weight_decay=0.0),
+          train_data=DataConfig(
+              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
+              is_training=True,
+              global_batch_size=train_batch_size),
+          validation_data=DataConfig(
+              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
+              is_training=False,
+              global_batch_size=eval_batch_size)),
+      trainer=cfg.TrainerConfig(
+          steps_per_loop=steps_per_epoch,
+          summary_interval=steps_per_epoch,
+          checkpoint_interval=steps_per_epoch,
+          train_steps=20000,
+          validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
+          validation_interval=steps_per_epoch,
+          optimizer_config=optimization.OptimizationConfig({
+              'optimizer': {
+                  'type': 'sgd',
+                  'sgd': {
+                      'momentum': 0.9,
+                      'global_clipnorm': 1.0,
+                  }
+              },
+              'learning_rate': {
+                  'type': 'cosine',
+                  'cosine': {
+                      'initial_learning_rate': 0.003,
+                      'decay_steps': 20000,
+                  }
+              }
+          })),
+      restrictions=[
+          'task.train_data.is_training != None',
+          'task.validation_data.is_training != None'
+      ])
+  return config
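Note: a back-of-the-envelope check of how the DeiT schedule above wires together. IMAGENET_TRAIN_EXAMPLES is defined elsewhere in this file as 1281167; the values below simply restate the config arithmetic:

IMAGENET_TRAIN_EXAMPLES = 1281167
train_batch_size = 4096
steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size  # 312
base_lr = 0.0005 * train_batch_size / 512                      # 0.004
train_steps = 300 * steps_per_epoch                            # 93600, i.e. 300 epochs
warmup_steps = 5 * steps_per_epoch                             # 1560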
official/vision/configs/image_classification_test.py

@@ -29,7 +29,10 @@ class ImageClassificationConfigTest(tf.test.TestCase, parameterized.TestCase):
       ('resnet_imagenet',),
       ('resnet_rs_imagenet',),
       ('revnet_imagenet',),
-      ('mobilenet_imagenet'),
+      ('mobilenet_imagenet',),
+      ('deit_imagenet_pretrain',),
+      ('vit_imagenet_pretrain',),
+      ('vit_imagenet_finetune',),
   )
   def test_image_classification_configs(self, config_name):
     config = exp_factory.get_exp_config(config_name)
official/vision/modeling/backbones/__init__.py

@@ -23,3 +23,4 @@ from official.vision.modeling.backbones.resnet_deeplab import DilatedResNet
 from official.vision.modeling.backbones.revnet import RevNet
 from official.vision.modeling.backbones.spinenet import SpineNet
 from official.vision.modeling.backbones.spinenet_mobile import SpineNetMobile
+from official.vision.modeling.backbones.vit import VisionTransformer
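Note: this re-export lets callers reach the backbone through the package import, which is how classification_model_test.py below uses it:

from official.vision.modeling import backbones

# Same class as official.vision.modeling.backbones.vit.VisionTransformer.
backbone = backbones.VisionTransformer(mlp_dim=768, num_heads=3,
                                       num_layers=12, hidden_size=192)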
official/vision/modeling/backbones/vit.py  (new file, 0 → 100644)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""VisionTransformer models."""

from typing import Optional, Tuple

from absl import logging
import tensorflow as tf

from official.modeling import activations
from official.vision.modeling.backbones import factory
from official.vision.modeling.backbones.vit_specs import VIT_SPECS
from official.vision.modeling.layers import nn_blocks
from official.vision.modeling.layers import nn_layers

layers = tf.keras.layers


class AddPositionEmbs(tf.keras.layers.Layer):
  """Adds (optionally learned) positional embeddings to the inputs."""

  def __init__(self,
               posemb_init: Optional[tf.keras.initializers.Initializer] = None,
               posemb_origin_shape: Optional[Tuple[int, int]] = None,
               posemb_target_shape: Optional[Tuple[int, int]] = None,
               **kwargs):
    """Constructs Postional Embedding module.

    The logic of this module is: the learnable positional embeddings length
    will be determined by the inputs_shape or posemb_origin_shape (if provided)
    during the construction. If the posemb_target_shape is provided and is
    different from the positional embeddings length, the embeddings will be
    interpolated during the forward call.

    Args:
      posemb_init: The positional embedding initializer.
      posemb_origin_shape: The intended positional embedding shape.
      posemb_target_shape: The potential target shape positional embedding may
        be interpolated to.
      **kwargs: other args.
    """
    super().__init__(**kwargs)
    self.posemb_init = posemb_init
    self.posemb_origin_shape = posemb_origin_shape
    self.posemb_target_shape = posemb_target_shape

  def build(self, inputs_shape):
    if self.posemb_origin_shape is not None:
      pos_emb_length = self.posemb_origin_shape[0] * self.posemb_origin_shape[1]
    else:
      pos_emb_length = inputs_shape[1]
    pos_emb_shape = (1, pos_emb_length, inputs_shape[2])
    self.pos_embedding = self.add_weight(
        'pos_embedding', pos_emb_shape, initializer=self.posemb_init)

  def _interpolate(self, pos_embedding: tf.Tensor, from_shape: Tuple[int, int],
                   to_shape: Tuple[int, int]) -> tf.Tensor:
    """Interpolates the positional embeddings."""
    logging.info('Interpolating postional embedding from length: %d to %d',
                 from_shape, to_shape)
    grid_emb = tf.reshape(pos_embedding, [1] + list(from_shape) + [-1])
    # NOTE: Using BILINEAR interpolation by default.
    grid_emb = tf.image.resize(grid_emb, to_shape)
    return tf.reshape(grid_emb, [1, to_shape[0] * to_shape[1], -1])

  def call(self, inputs, inputs_positions=None):
    del inputs_positions
    pos_embedding = self.pos_embedding
    # inputs.shape is (batch_size, seq_len, emb_dim).
    if inputs.shape[1] != pos_embedding.shape[1]:
      pos_embedding = self._interpolate(
          pos_embedding,
          from_shape=self.posemb_origin_shape,
          to_shape=self.posemb_target_shape)
    pos_embedding = tf.cast(pos_embedding, inputs.dtype)
    return inputs + pos_embedding


class TokenLayer(tf.keras.layers.Layer):
  """A simple layer to wrap token parameters."""

  def build(self, inputs_shape):
    self.cls = self.add_weight(
        'cls', (1, 1, inputs_shape[-1]), initializer='zeros')

  def call(self, inputs):
    cls = tf.cast(self.cls, inputs.dtype)
    cls = cls + tf.zeros_like(inputs[:, 0:1])  # A hacky way to tile.
    x = tf.concat([cls, inputs], axis=1)
    return x


class Encoder(tf.keras.layers.Layer):
  """Transformer Encoder."""

  def __init__(self,
               num_layers,
               mlp_dim,
               num_heads,
               dropout_rate=0.1,
               attention_dropout_rate=0.1,
               kernel_regularizer=None,
               inputs_positions=None,
               init_stochastic_depth_rate=0.0,
               kernel_initializer='glorot_uniform',
               add_pos_embed=True,
               pos_embed_origin_shape=None,
               pos_embed_target_shape=None,
               **kwargs):
    super().__init__(**kwargs)
    self._num_layers = num_layers
    self._mlp_dim = mlp_dim
    self._num_heads = num_heads
    self._dropout_rate = dropout_rate
    self._attention_dropout_rate = attention_dropout_rate
    self._kernel_regularizer = kernel_regularizer
    self._inputs_positions = inputs_positions
    self._init_stochastic_depth_rate = init_stochastic_depth_rate
    self._kernel_initializer = kernel_initializer
    self._add_pos_embed = add_pos_embed
    self._pos_embed_origin_shape = pos_embed_origin_shape
    self._pos_embed_target_shape = pos_embed_target_shape

  def build(self, input_shape):
    if self._add_pos_embed:
      self._pos_embed = AddPositionEmbs(
          posemb_init=tf.keras.initializers.RandomNormal(stddev=0.02),
          posemb_origin_shape=self._pos_embed_origin_shape,
          posemb_target_shape=self._pos_embed_target_shape,
          name='posembed_input')
    self._dropout = layers.Dropout(rate=self._dropout_rate)

    self._encoder_layers = []
    # Set layer norm epsilons to 1e-6 to be consistent with JAX implementation.
    # https://flax.readthedocs.io/en/latest/_autosummary/flax.deprecated.nn.LayerNorm.html
    for i in range(self._num_layers):
      encoder_layer = nn_blocks.TransformerEncoderBlock(
          inner_activation=activations.gelu,
          num_attention_heads=self._num_heads,
          inner_dim=self._mlp_dim,
          output_dropout=self._dropout_rate,
          attention_dropout=self._attention_dropout_rate,
          kernel_regularizer=self._kernel_regularizer,
          kernel_initializer=self._kernel_initializer,
          norm_first=True,
          stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
              self._init_stochastic_depth_rate, i + 1, self._num_layers),
          norm_epsilon=1e-6)
      self._encoder_layers.append(encoder_layer)
    self._norm = layers.LayerNormalization(epsilon=1e-6)
    super().build(input_shape)

  def call(self, inputs, training=None):
    x = inputs
    if self._add_pos_embed:
      x = self._pos_embed(x, inputs_positions=self._inputs_positions)
    x = self._dropout(x, training=training)

    for encoder_layer in self._encoder_layers:
      x = encoder_layer(x, training=training)
    x = self._norm(x)
    return x

  def get_config(self):
    config = super().get_config()
    updates = {
        'num_layers': self._num_layers,
        'mlp_dim': self._mlp_dim,
        'num_heads': self._num_heads,
        'dropout_rate': self._dropout_rate,
        'attention_dropout_rate': self._attention_dropout_rate,
        'kernel_regularizer': self._kernel_regularizer,
        'inputs_positions': self._inputs_positions,
        'init_stochastic_depth_rate': self._init_stochastic_depth_rate,
        'kernel_initializer': self._kernel_initializer,
        'add_pos_embed': self._add_pos_embed,
        'pos_embed_origin_shape': self._pos_embed_origin_shape,
        'pos_embed_target_shape': self._pos_embed_target_shape,
    }
    config.update(updates)
    return config


class VisionTransformer(tf.keras.Model):
  """Class to build VisionTransformer family model."""

  def __init__(self,
               mlp_dim=3072,
               num_heads=12,
               num_layers=12,
               attention_dropout_rate=0.0,
               dropout_rate=0.1,
               init_stochastic_depth_rate=0.0,
               input_specs=layers.InputSpec(shape=[None, None, None, 3]),
               patch_size=16,
               hidden_size=768,
               representation_size=0,
               pooler='token',
               kernel_regularizer=None,
               original_init: bool = True,
               pos_embed_shape: Optional[Tuple[int, int]] = None):
    """VisionTransformer initialization function."""
    self._mlp_dim = mlp_dim
    self._num_heads = num_heads
    self._num_layers = num_layers
    self._hidden_size = hidden_size
    self._patch_size = patch_size

    inputs = tf.keras.Input(shape=input_specs.shape[1:])

    x = layers.Conv2D(
        filters=hidden_size,
        kernel_size=patch_size,
        strides=patch_size,
        padding='valid',
        kernel_regularizer=kernel_regularizer,
        kernel_initializer='lecun_normal' if original_init else 'he_uniform')(
            inputs)
    if tf.keras.backend.image_data_format() == 'channels_last':
      rows_axis, cols_axis = (1, 2)
    else:
      rows_axis, cols_axis = (2, 3)
      # The reshape below assumes the data_format is 'channels_last,' so
      # transpose to that. Once the data is flattened by the reshape, the
      # data_format is irrelevant, so no need to update
      # tf.keras.backend.image_data_format.
      x = tf.transpose(x, perm=[0, 2, 3, 1])

    pos_embed_target_shape = (x.shape[rows_axis], x.shape[cols_axis])
    seq_len = (input_specs.shape[rows_axis] // patch_size) * (
        input_specs.shape[cols_axis] // patch_size)
    x = tf.reshape(x, [-1, seq_len, hidden_size])

    # If we want to add a class token, add it here.
    if pooler == 'token':
      x = TokenLayer(name='cls')(x)

    x = Encoder(
        num_layers=num_layers,
        mlp_dim=mlp_dim,
        num_heads=num_heads,
        dropout_rate=dropout_rate,
        attention_dropout_rate=attention_dropout_rate,
        kernel_regularizer=kernel_regularizer,
        kernel_initializer='glorot_uniform' if original_init else dict(
            class_name='TruncatedNormal', config=dict(stddev=.02)),
        init_stochastic_depth_rate=init_stochastic_depth_rate,
        pos_embed_origin_shape=pos_embed_shape,
        pos_embed_target_shape=pos_embed_target_shape)(
            x)

    if pooler == 'token':
      x = x[:, 0]
    elif pooler == 'gap':
      x = tf.reduce_mean(x, axis=1)
    elif pooler == 'none':
      x = tf.identity(x, name='encoded_tokens')
    else:
      raise ValueError(f'unrecognized pooler type: {pooler}')

    if representation_size:
      x = tf.keras.layers.Dense(
          representation_size,
          kernel_regularizer=kernel_regularizer,
          name='pre_logits',
          kernel_initializer='lecun_normal' if original_init else 'he_uniform')(
              x)
      x = tf.nn.tanh(x)
    else:
      x = tf.identity(x, name='pre_logits')

    if pooler == 'none':
      endpoints = {'encoded_tokens': x}
    else:
      endpoints = {
          'pre_logits':
              tf.reshape(x, [-1, 1, 1, representation_size or hidden_size])
      }

    super(VisionTransformer, self).__init__(inputs=inputs, outputs=endpoints)


@factory.register_backbone_builder('vit')
def build_vit(input_specs,
              backbone_config,
              norm_activation_config,
              l2_regularizer=None):
  """Build ViT model."""
  del norm_activation_config
  backbone_type = backbone_config.type
  backbone_cfg = backbone_config.get()
  assert backbone_type == 'vit', (f'Inconsistent backbone type '
                                  f'{backbone_type}')
  backbone_cfg.override(VIT_SPECS[backbone_cfg.model_name])

  return VisionTransformer(
      mlp_dim=backbone_cfg.transformer.mlp_dim,
      num_heads=backbone_cfg.transformer.num_heads,
      num_layers=backbone_cfg.transformer.num_layers,
      attention_dropout_rate=backbone_cfg.transformer.attention_dropout_rate,
      dropout_rate=backbone_cfg.transformer.dropout_rate,
      init_stochastic_depth_rate=backbone_cfg.init_stochastic_depth_rate,
      input_specs=input_specs,
      patch_size=backbone_cfg.patch_size,
      hidden_size=backbone_cfg.hidden_size,
      representation_size=backbone_cfg.representation_size,
      pooler=backbone_cfg.pooler,
      kernel_regularizer=l2_regularizer,
      original_init=backbone_cfg.original_init,
      pos_embed_shape=backbone_cfg.pos_embed_shape)
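Note: a hedged usage sketch of the new backbone as a standalone Keras model. The constructor defaults above correspond to ViT-B/16, and the endpoint names follow the endpoints dict in __init__:

import tensorflow as tf
from official.vision.modeling.backbones import vit

backbone = vit.VisionTransformer(
    input_specs=tf.keras.layers.InputSpec(shape=[None, 224, 224, 3]))
features = backbone(tf.random.uniform([2, 224, 224, 3]))
print(features['pre_logits'].shape)  # (2, 1, 1, 768) with token pooling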
official/vision/modeling/backbones/vit_specs.py  (new file, 0 → 100644)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""VisionTransformer backbone specs."""

import immutabledict

VIT_SPECS = immutabledict.immutabledict({
    'vit-ti16':
        dict(
            hidden_size=192,
            patch_size=16,
            transformer=dict(mlp_dim=768, num_heads=3, num_layers=12),
        ),
    'vit-s16':
        dict(
            hidden_size=384,
            patch_size=16,
            transformer=dict(mlp_dim=1536, num_heads=6, num_layers=12),
        ),
    'vit-b16':
        dict(
            hidden_size=768,
            patch_size=16,
            transformer=dict(mlp_dim=3072, num_heads=12, num_layers=12),
        ),
    'vit-b32':
        dict(
            hidden_size=768,
            patch_size=32,
            transformer=dict(mlp_dim=3072, num_heads=12, num_layers=12),
        ),
    'vit-l16':
        dict(
            hidden_size=1024,
            patch_size=16,
            transformer=dict(mlp_dim=4096, num_heads=16, num_layers=24),
        ),
    'vit-l32':
        dict(
            hidden_size=1024,
            patch_size=32,
            transformer=dict(mlp_dim=4096, num_heads=16, num_layers=24),
        ),
    'vit-h14':
        dict(
            hidden_size=1280,
            patch_size=14,
            transformer=dict(mlp_dim=5120, num_heads=16, num_layers=32),
        ),
    'vit-g14':
        dict(
            hidden_size=1664,
            patch_size=14,
            transformer=dict(mlp_dim=8192, num_heads=16, num_layers=48),
        ),
})
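Note: these specs are plain dicts that build_vit (in vit.py above) applies to the backbone config via Config.override before constructing the model. A small sketch:

from official.vision.configs import backbones
from official.vision.modeling.backbones.vit_specs import VIT_SPECS

cfg = backbones.VisionTransformer(model_name='vit-s16')
cfg.override(VIT_SPECS[cfg.model_name])
print(cfg.hidden_size, cfg.patch_size, cfg.transformer.mlp_dim)  # 384 16 1536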
official/vision/modeling/backbones/vit_test.py  (new file, 0 → 100644)

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for VIT."""

from absl.testing import parameterized
import tensorflow as tf

from official.vision.modeling.backbones import vit


class VisionTransformerTest(parameterized.TestCase, tf.test.TestCase):

  @parameterized.parameters(
      (224, 85798656),
      (256, 85844736),
  )
  def test_network_creation(self, input_size, params_count):
    """Test creation of VisionTransformer family models."""
    tf.keras.backend.set_image_data_format('channels_last')
    input_specs = tf.keras.layers.InputSpec(
        shape=[2, input_size, input_size, 3])
    network = vit.VisionTransformer(input_specs=input_specs)
    inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
    _ = network(inputs)
    self.assertEqual(network.count_params(), params_count)

  def test_network_none_pooler(self):
    tf.keras.backend.set_image_data_format('channels_last')
    input_size = 256
    input_specs = tf.keras.layers.InputSpec(
        shape=[2, input_size, input_size, 3])
    network = vit.VisionTransformer(
        input_specs=input_specs,
        patch_size=16,
        pooler='none',
        representation_size=128,
        pos_embed_shape=(14, 14))  # (224 // 16)
    inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
    output = network(inputs)['encoded_tokens']
    self.assertEqual(output.shape, [1, 256, 128])

  def test_posembedding_interpolation(self):
    tf.keras.backend.set_image_data_format('channels_last')
    input_size = 256
    input_specs = tf.keras.layers.InputSpec(
        shape=[2, input_size, input_size, 3])
    network = vit.VisionTransformer(
        input_specs=input_specs,
        patch_size=16,
        pooler='gap',
        pos_embed_shape=(14, 14))  # (224 // 16)
    inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
    output = network(inputs)['pre_logits']
    self.assertEqual(output.shape, [1, 1, 1, 768])


if __name__ == '__main__':
  tf.test.main()
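Note: the two expected counts in test_network_creation differ only in the learned positional-embedding table, since every other ViT-B/16 weight is independent of input resolution. A quick arithmetic check (hidden size 768 and one class token, as in the code above):

hidden_size = 768
tokens_224 = (224 // 16) ** 2 + 1  # 196 patches + 1 cls token = 197
tokens_256 = (256 // 16) ** 2 + 1  # 256 patches + 1 cls token = 257
assert (tokens_256 - tokens_224) * hidden_size == 85844736 - 85798656  # 46080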
official/vision/modeling/classification_model_test.py

@@ -27,20 +27,49 @@ from official.vision.modeling import classification_model
 class ClassificationNetworkTest(parameterized.TestCase, tf.test.TestCase):
 
+  @parameterized.parameters(
+      (192 * 4, 3, 12, 192, 5524416),
+      (384 * 4, 6, 12, 384, 21665664),
+  )
+  def test_vision_transformer_creation(self, mlp_dim, num_heads, num_layers,
+                                       hidden_size, num_params):
+    """Test for creation of a Vision Transformer classifier."""
+    inputs = np.random.rand(2, 224, 224, 3)
+    tf.keras.backend.set_image_data_format('channels_last')
+
+    backbone = backbones.VisionTransformer(
+        mlp_dim=mlp_dim,
+        num_heads=num_heads,
+        num_layers=num_layers,
+        hidden_size=hidden_size,
+        input_specs=tf.keras.layers.InputSpec(shape=[None, 224, 224, 3]),
+    )
+    self.assertEqual(backbone.count_params(), num_params)
+
+    num_classes = 1000
+    model = classification_model.ClassificationModel(
+        backbone=backbone,
+        num_classes=num_classes,
+        dropout_rate=0.2,
+    )
+    logits = model(inputs)
+    self.assertAllEqual([2, num_classes], logits.numpy().shape)
+
   @parameterized.parameters(
       (128, 50, 'relu'),
       (128, 50, 'relu'),
       (128, 50, 'swish'),
   )
   def test_resnet_network_creation(self, input_size, resnet_model_id,
                                    activation):
     """Test for creation of a ResNet-50 classifier."""
     inputs = np.random.rand(2, input_size, input_size, 3)
     tf.keras.backend.set_image_data_format('channels_last')
 
     backbone = backbones.ResNet(
         model_id=resnet_model_id, activation=activation)
     self.assertEqual(backbone.count_params(), 23561152)
 
     num_classes = 1000
official/vision/modeling/layers/nn_blocks.py

@@ -21,6 +21,7 @@ from absl import logging
 import tensorflow as tf
 
 from official.modeling import tf_utils
+from official.nlp import modeling as nlp_modeling
 from official.vision.modeling.layers import nn_layers
 
@@ -538,8 +539,8 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
       se_inner_activation: A `str` name of squeeze-excitation inner activation.
       se_gating_activation: A `str` name of squeeze-excitation gating
         activation.
       se_round_down_protect: A `bool` of whether round down more than 10% will
         be allowed in SE layer.
       expand_se_in_filters: A `bool` of whether or not to expand in_filter in
         squeeze and excitation layer.
       depthwise_activation: A `str` name of the activation function for
@@ -547,9 +548,8 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
       use_sync_bn: A `bool`. If True, use synchronized batch normalization.
       dilation_rate: An `int` that specifies the dilation rate to use for.
       divisible_by: An `int` that ensures all inner dimensions are divisible by
-        this number.
-      dilated convolution: An `int` to specify the same value for all spatial
-        dimensions.
+        this number. dilated convolution: An `int` to specify the same value for
+        all spatial dimensions.
       regularize_depthwise: A `bool` of whether or not apply regularization on
         depthwise.
       use_depthwise: A `bool` of whether to uses fused convolutions instead of
@@ -1048,7 +1048,7 @@ class ReversibleLayer(tf.keras.layers.Layer):
       (bottleneck) residual functions. Where the input to the reversible layer
       is x, the input gets partitioned in the channel dimension and the
       forward pass follows (eq8): x = [x1; x2], z1 = x1 + f(x2), y2 = x2 +
       g(z1), y1 = stop_gradient(z1).
       g: A `tf.keras.layers.Layer` instance of `g` inner block referred to in
         paper. Detailed explanation same as above as `f` arg.
       manual_grads: A `bool` [Testing Only] of whether to manually take
@@ -1204,7 +1204,8 @@ class ReversibleLayer(tf.keras.layers.Layer):
 @tf.keras.utils.register_keras_serializable(package='Vision')
 class DepthwiseSeparableConvBlock(tf.keras.layers.Layer):
-  """Creates an depthwise separable convolution block with batch normalization."""
+  """Creates a depthwise separable convolution block with batch normalization.
+  """
 
   def __init__(
       self,
@@ -1354,10 +1355,10 @@ class TuckerConvBlock(tf.keras.layers.Layer):
     Args:
       in_filters: An `int` number of filters of the input tensor.
       out_filters: An `int` number of filters of the output tensor.
       input_compression_ratio: An `float` of compression ratio for input
         filters.
       output_compression_ratio: An `float` of compression ratio for output
         filters.
       strides: An `int` block stride. If greater than 1, this block will
         ultimately downsample the input.
       kernel_size: An `int` kernel_size of the depthwise conv layer.
@@ -1510,11 +1511,114 @@ class TuckerConvBlock(tf.keras.layers.Layer):
     x = self._conv2(x)
     x = self._norm2(x)
 
     if (self._use_residual and
         self._in_filters == self._out_filters and
         self._strides == 1):
       if self._stochastic_depth:
         x = self._stochastic_depth(x, training=training)
       x = self._add([x, shortcut])
 
     return x
+
+
+class TransformerEncoderBlock(nlp_modeling.layers.TransformerEncoderBlock):
+  """TransformerEncoderBlock layer with stochastic depth."""
+
+  def __init__(self,
+               *args,
+               stochastic_depth_drop_rate=0.0,
+               return_attention=False,
+               **kwargs):
+    """Initializes TransformerEncoderBlock."""
+    super().__init__(*args, **kwargs)
+    self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
+    self._return_attention = return_attention
+
+  def build(self, input_shape):
+    if self._stochastic_depth_drop_rate:
+      self._stochastic_depth = nn_layers.StochasticDepth(
+          self._stochastic_depth_drop_rate)
+    else:
+      self._stochastic_depth = lambda x, *args, **kwargs: tf.identity(x)
+
+    super().build(input_shape)
+
+  def get_config(self):
+    config = {'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate}
+    base_config = super().get_config()
+    return dict(list(base_config.items()) + list(config.items()))
+
+  def call(self, inputs, training=None):
+    """Transformer self-attention encoder block call."""
+    if isinstance(inputs, (list, tuple)):
+      if len(inputs) == 2:
+        input_tensor, attention_mask = inputs
+        key_value = None
+      elif len(inputs) == 3:
+        input_tensor, key_value, attention_mask = inputs
+      else:
+        raise ValueError('Unexpected inputs to %s with length at %d' %
+                         (self.__class__, len(inputs)))
+    else:
+      input_tensor, key_value, attention_mask = (inputs, None, None)
+
+    if self._output_range:
+      if self._norm_first:
+        source_tensor = input_tensor[:, 0:self._output_range, :]
+        input_tensor = self._attention_layer_norm(input_tensor)
+        if key_value is not None:
+          key_value = self._attention_layer_norm(key_value)
+      target_tensor = input_tensor[:, 0:self._output_range, :]
+      if attention_mask is not None:
+        attention_mask = attention_mask[:, 0:self._output_range, :]
+    else:
+      if self._norm_first:
+        source_tensor = input_tensor
+        input_tensor = self._attention_layer_norm(input_tensor)
+        if key_value is not None:
+          key_value = self._attention_layer_norm(key_value)
+      target_tensor = input_tensor
+
+    if key_value is None:
+      key_value = input_tensor
+    attention_output, attention_scores = self._attention_layer(
+        query=target_tensor,
+        value=key_value,
+        attention_mask=attention_mask,
+        return_attention_scores=True)
+    attention_output = self._attention_dropout(attention_output)
+
+    if self._norm_first:
+      attention_output = source_tensor + self._stochastic_depth(
+          attention_output, training=training)
+    else:
+      attention_output = self._attention_layer_norm(
+          target_tensor +
+          self._stochastic_depth(attention_output, training=training))
+    if self._norm_first:
+      source_attention_output = attention_output
+      attention_output = self._output_layer_norm(attention_output)
+    inner_output = self._intermediate_dense(attention_output)
+    inner_output = self._intermediate_activation_layer(inner_output)
+    inner_output = self._inner_dropout_layer(inner_output)
+    layer_output = self._output_dense(inner_output)
+    layer_output = self._output_dropout(layer_output)
+
+    if self._norm_first:
+      if self._return_attention:
+        return source_attention_output + self._stochastic_depth(
+            layer_output, training=training), attention_scores
+      else:
+        return source_attention_output + self._stochastic_depth(
+            layer_output, training=training)
+
+    # During mixed precision training, layer norm output is always fp32 for
+    # now. Casts fp32 for the subsequent add.
+    layer_output = tf.cast(layer_output, tf.float32)
+    if self._return_attention:
+      return self._output_layer_norm(
+          layer_output +
+          self._stochastic_depth(attention_output, training=training)
+      ), attention_scores
+    else:
+      return self._output_layer_norm(
+          layer_output +
+          self._stochastic_depth(attention_output, training=training))
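Note: the Encoder in vit.py passes nn_layers.get_stochastic_depth_rate(init_rate, i + 1, num_layers) into each of these blocks, so drop rates grow with depth. A sketch of the schedule, assuming the usual linear form for that helper (its implementation is not part of this diff):

# Assumed linear schedule of nn_layers.get_stochastic_depth_rate.
def stochastic_depth_rate(init_rate, i, n):
  return init_rate * i / n

num_layers, init_rate = 12, 0.1
rates = [stochastic_depth_rate(init_rate, i + 1, num_layers)
         for i in range(num_layers)]
# rates[0] ~ 0.0083, rates[-1] == 0.1: deeper blocks are dropped more often.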