ModelZoo / ResNet50_tensorflow

Commit bdca62cc authored Sep 01, 2021 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 394355848
parent c5b6d8da

Showing 4 changed files with 59 additions and 17 deletions (+59, -17)
official/vision/beta/configs/backbones.py                (+1, -0)
official/vision/beta/modeling/backbones/resnet.py        (+31, -8)
official/vision/beta/modeling/backbones/resnet_test.py   (+1, -0)
official/vision/beta/modeling/layers/nn_blocks.py        (+26, -9)
official/vision/beta/configs/backbones.py

@@ -32,6 +32,7 @@ class ResNet(hyperparams.Config):
   stochastic_depth_drop_rate: float = 0.0
   resnetd_shortcut: bool = False
   replace_stem_max_pool: bool = False
+  bn_trainable: bool = True


 @dataclasses.dataclass
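Taken together, the four files thread a single `bn_trainable` flag from the backbone config down to every batch-norm layer, so a job can keep BN layers frozen while training the rest of the network. As a rough sketch of how the new config field might be set (the import path follows the file location above; `model_id` is assumed to be an existing field of this config and is not shown in this hunk):

# Hypothetical usage sketch, not part of the commit.
from official.vision.beta.configs import backbones

resnet_cfg = backbones.ResNet(
    model_id=50,          # assumed existing field of the config
    bn_trainable=False)   # new field added by this commit
assert resnet_cfg.bn_trainable is False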
official/vision/beta/modeling/backbones/resnet.py

@@ -127,6 +127,7 @@ class ResNet(tf.keras.Model):
                kernel_initializer: str = 'VarianceScaling',
                kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
                bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
+               bn_trainable: bool = True,
                **kwargs):
     """Initializes a ResNet model.

@@ -153,6 +154,8 @@ class ResNet(tf.keras.Model):
         Conv2D. Default to None.
       bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2D.
         Default to None.
+      bn_trainable: A `bool` that indicates whether batch norm layers should be
+        trainable. Default to True.
       **kwargs: Additional keyword arguments to be passed.
     """
     self._model_id = model_id

@@ -174,6 +177,7 @@ class ResNet(tf.keras.Model):
     self._kernel_initializer = kernel_initializer
     self._kernel_regularizer = kernel_regularizer
     self._bias_regularizer = bias_regularizer
+    self._bn_trainable = bn_trainable

     if tf.keras.backend.image_data_format() == 'channels_last':
       bn_axis = -1

@@ -195,7 +199,10 @@ class ResNet(tf.keras.Model):
           bias_regularizer=self._bias_regularizer)(
               inputs)
-      x = self._norm(
-          axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
-              x)
+      x = self._norm(
+          axis=bn_axis,
+          momentum=norm_momentum,
+          epsilon=norm_epsilon,
+          trainable=bn_trainable)(
+              x)
       x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
     elif stem_type == 'v1':

@@ -210,7 +217,10 @@ class ResNet(tf.keras.Model):
           bias_regularizer=self._bias_regularizer)(
               inputs)
-      x = self._norm(
-          axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
-              x)
+      x = self._norm(
+          axis=bn_axis,
+          momentum=norm_momentum,
+          epsilon=norm_epsilon,
+          trainable=bn_trainable)(
+              x)
       x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
       x = layers.Conv2D(

@@ -224,7 +234,10 @@ class ResNet(tf.keras.Model):
           bias_regularizer=self._bias_regularizer)(
               x)
-      x = self._norm(
-          axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
-              x)
+      x = self._norm(
+          axis=bn_axis,
+          momentum=norm_momentum,
+          epsilon=norm_epsilon,
+          trainable=bn_trainable)(
+              x)
       x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
       x = layers.Conv2D(

@@ -238,7 +251,10 @@ class ResNet(tf.keras.Model):
           bias_regularizer=self._bias_regularizer)(
               x)
-      x = self._norm(
-          axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
-              x)
+      x = self._norm(
+          axis=bn_axis,
+          momentum=norm_momentum,
+          epsilon=norm_epsilon,
+          trainable=bn_trainable)(
+              x)
       x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
     else:

@@ -256,7 +272,10 @@ class ResNet(tf.keras.Model):
           bias_regularizer=self._bias_regularizer)(
               x)
-      x = self._norm(
-          axis=bn_axis, momentum=norm_momentum, epsilon=norm_epsilon)(
-              x)
+      x = self._norm(
+          axis=bn_axis,
+          momentum=norm_momentum,
+          epsilon=norm_epsilon,
+          trainable=bn_trainable)(
+              x)
       x = tf_utils.get_activation(activation, use_keras_layer=True)(x)
     else:

@@ -324,7 +343,8 @@ class ResNet(tf.keras.Model):
           activation=self._activation,
           use_sync_bn=self._use_sync_bn,
           norm_momentum=self._norm_momentum,
-          norm_epsilon=self._norm_epsilon)(
+          norm_epsilon=self._norm_epsilon,
+          bn_trainable=self._bn_trainable)(
               inputs)

     for _ in range(1, block_repeats):

@@ -341,7 +361,8 @@ class ResNet(tf.keras.Model):
           activation=self._activation,
           use_sync_bn=self._use_sync_bn,
           norm_momentum=self._norm_momentum,
-          norm_epsilon=self._norm_epsilon)(
+          norm_epsilon=self._norm_epsilon,
+          bn_trainable=self._bn_trainable)(
               x)

     return tf.keras.layers.Activation('linear', name=name)(x)

@@ -362,6 +383,7 @@ class ResNet(tf.keras.Model):
         'kernel_initializer': self._kernel_initializer,
         'kernel_regularizer': self._kernel_regularizer,
         'bias_regularizer': self._bias_regularizer,
+        'bn_trainable': self._bn_trainable
     }
     return config_dict

@@ -400,4 +422,5 @@ def build_resnet(
       use_sync_bn=norm_activation_config.use_sync_bn,
       norm_momentum=norm_activation_config.norm_momentum,
       norm_epsilon=norm_activation_config.norm_epsilon,
-      kernel_regularizer=l2_regularizer)
+      kernel_regularizer=l2_regularizer,
+      bn_trainable=backbone_cfg.bn_trainable)
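A minimal sketch of what the new argument does at the model level, assuming `model_id` is the only required constructor argument and that the norm layers are standard `tf.keras.layers.BatchNormalization` instances (neither assumption is stated by this diff):

# Hypothetical sketch; not part of the commit.
import tensorflow as tf
from official.vision.beta.modeling.backbones import resnet

backbone = resnet.ResNet(model_id=50, bn_trainable=False)

# trainable=False is forwarded to every norm layer, so the BN layers keep
# their moving statistics frozen and contribute no trainable variables.
bn_layers = [layer for layer in backbone.submodules
             if isinstance(layer, tf.keras.layers.BatchNormalization)]
assert bn_layers and all(not layer.trainable for layer in bn_layers)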
official/vision/beta/modeling/backbones/resnet_test.py

@@ -135,6 +135,7 @@ class ResNetTest(parameterized.TestCase, tf.test.TestCase):
         kernel_initializer='VarianceScaling',
         kernel_regularizer=None,
         bias_regularizer=None,
+        bn_trainable=True
     )
     network = resnet.ResNet(**kwargs)
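The one-line test change above adds the new key to the test's kwargs; a sketch of the serialization round-trip it exercises, assuming `kwargs` is the full dict built in the test and that `ResNet.from_config` simply re-invokes the constructor (an assumption, not shown here):

# Hypothetical sketch of the get_config()/from_config() round-trip.
network = resnet.ResNet(**kwargs)
config = network.get_config()
assert config['bn_trainable'] is True   # new key now survives get_config()
restored = resnet.ResNet.from_config(config)
assert restored.get_config() == config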
official/vision/beta/modeling/layers/nn_blocks.py

@@ -72,6 +72,7 @@ class ResidualBlock(tf.keras.layers.Layer):
                use_sync_bn=False,
                norm_momentum=0.99,
                norm_epsilon=0.001,
+               bn_trainable=True,
                **kwargs):
     """Initializes a residual block with BN after convolutions.

@@ -99,6 +100,8 @@ class ResidualBlock(tf.keras.layers.Layer):
       use_sync_bn: A `bool`. If True, use synchronized batch normalization.
       norm_momentum: A `float` of normalization momentum for the moving average.
       norm_epsilon: A `float` added to variance to avoid dividing by zero.
+      bn_trainable: A `bool` that indicates whether batch norm layers should be
+        trainable. Default to True.
       **kwargs: Additional keyword arguments to be passed.
     """
     super(ResidualBlock, self).__init__(**kwargs)

@@ -126,6 +129,7 @@ class ResidualBlock(tf.keras.layers.Layer):
     else:
       self._bn_axis = 1
     self._activation_fn = tf_utils.get_activation(activation)
+    self._bn_trainable = bn_trainable

   def build(self, input_shape):
     if self._use_projection:

@@ -140,7 +144,8 @@ class ResidualBlock(tf.keras.layers.Layer):
       self._norm0 = self._norm(
           axis=self._bn_axis,
           momentum=self._norm_momentum,
-          epsilon=self._norm_epsilon)
+          epsilon=self._norm_epsilon,
+          trainable=self._bn_trainable)

     self._conv1 = tf.keras.layers.Conv2D(
         filters=self._filters,

@@ -154,7 +159,8 @@ class ResidualBlock(tf.keras.layers.Layer):
     self._norm1 = self._norm(
         axis=self._bn_axis,
         momentum=self._norm_momentum,
-        epsilon=self._norm_epsilon)
+        epsilon=self._norm_epsilon,
+        trainable=self._bn_trainable)
     self._conv2 = tf.keras.layers.Conv2D(
         filters=self._filters,

@@ -168,7 +174,8 @@ class ResidualBlock(tf.keras.layers.Layer):
     self._norm2 = self._norm(
         axis=self._bn_axis,
         momentum=self._norm_momentum,
-        epsilon=self._norm_epsilon)
+        epsilon=self._norm_epsilon,
+        trainable=self._bn_trainable)

     if self._se_ratio and self._se_ratio > 0 and self._se_ratio <= 1:
       self._squeeze_excitation = nn_layers.SqueezeExcitation(

@@ -203,7 +210,8 @@ class ResidualBlock(tf.keras.layers.Layer):
         'activation': self._activation,
         'use_sync_bn': self._use_sync_bn,
         'norm_momentum': self._norm_momentum,
-        'norm_epsilon': self._norm_epsilon
+        'norm_epsilon': self._norm_epsilon,
+        'bn_trainable': self._bn_trainable
     }
     base_config = super(ResidualBlock, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))

@@ -249,6 +257,7 @@ class BottleneckBlock(tf.keras.layers.Layer):
                use_sync_bn=False,
                norm_momentum=0.99,
                norm_epsilon=0.001,
+               bn_trainable=True,
                **kwargs):
     """Initializes a standard bottleneck block with BN after convolutions.

@@ -277,6 +286,8 @@ class BottleneckBlock(tf.keras.layers.Layer):
       use_sync_bn: A `bool`. If True, use synchronized batch normalization.
       norm_momentum: A `float` of normalization momentum for the moving average.
       norm_epsilon: A `float` added to variance to avoid dividing by zero.
+      bn_trainable: A `bool` that indicates whether batch norm layers should be
+        trainable. Default to True.
       **kwargs: Additional keyword arguments to be passed.
     """
     super(BottleneckBlock, self).__init__(**kwargs)

@@ -303,6 +314,7 @@ class BottleneckBlock(tf.keras.layers.Layer):
       self._bn_axis = -1
     else:
       self._bn_axis = 1
+    self._bn_trainable = bn_trainable

   def build(self, input_shape):
     if self._use_projection:

@@ -330,7 +342,8 @@ class BottleneckBlock(tf.keras.layers.Layer):
       self._norm0 = self._norm(
           axis=self._bn_axis,
           momentum=self._norm_momentum,
-          epsilon=self._norm_epsilon)
+          epsilon=self._norm_epsilon,
+          trainable=self._bn_trainable)

     self._conv1 = tf.keras.layers.Conv2D(
         filters=self._filters,

@@ -343,7 +356,8 @@ class BottleneckBlock(tf.keras.layers.Layer):
     self._norm1 = self._norm(
         axis=self._bn_axis,
         momentum=self._norm_momentum,
-        epsilon=self._norm_epsilon)
+        epsilon=self._norm_epsilon,
+        trainable=self._bn_trainable)
     self._activation1 = tf_utils.get_activation(
         self._activation, use_keras_layer=True)

@@ -360,7 +374,8 @@ class BottleneckBlock(tf.keras.layers.Layer):
     self._norm2 = self._norm(
         axis=self._bn_axis,
         momentum=self._norm_momentum,
-        epsilon=self._norm_epsilon)
+        epsilon=self._norm_epsilon,
+        trainable=self._bn_trainable)
     self._activation2 = tf_utils.get_activation(
         self._activation, use_keras_layer=True)

@@ -375,7 +390,8 @@ class BottleneckBlock(tf.keras.layers.Layer):
     self._norm3 = self._norm(
         axis=self._bn_axis,
         momentum=self._norm_momentum,
-        epsilon=self._norm_epsilon)
+        epsilon=self._norm_epsilon,
+        trainable=self._bn_trainable)
     self._activation3 = tf_utils.get_activation(
         self._activation, use_keras_layer=True)

@@ -414,7 +430,8 @@ class BottleneckBlock(tf.keras.layers.Layer):
         'activation': self._activation,
         'use_sync_bn': self._use_sync_bn,
         'norm_momentum': self._norm_momentum,
-        'norm_epsilon': self._norm_epsilon
+        'norm_epsilon': self._norm_epsilon,
+        'bn_trainable': self._bn_trainable
     }
     base_config = super(BottleneckBlock, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
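At the block level the flag works the same way; a sketch with `BottleneckBlock` (the required `filters`/`strides` arguments and the input shape are assumptions based on the block's usual interface, not shown in this diff):

# Hypothetical sketch; not part of the commit.
import tensorflow as tf
from official.vision.beta.modeling.layers import nn_blocks

block = nn_blocks.BottleneckBlock(filters=64, strides=1, bn_trainable=False)
# Identity shortcut: input channels must equal filters * 4.
outputs = block(tf.random.normal([2, 56, 56, 256]))

# get_config() now carries the flag, so the block round-trips with it intact.
assert block.get_config()['bn_trainable'] is False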