Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
1f3247f4
Unverified
Commit
1f3247f4
authored
Mar 27, 2020
by
Ayushman Kumar
Committed by
GitHub
Mar 27, 2020
Browse files
Merge pull request #6 from tensorflow/master
Updated
parents
370a4c8d
0265f59c
Changes
85
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
123 additions
and
130 deletions
+123
-130
official/vision/image_classification/resnet/resnet_config.py
official/vision/image_classification/resnet/resnet_config.py
+61
-0
official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py
...n/image_classification/resnet/resnet_ctl_imagenet_main.py
+2
-2
official/vision/image_classification/resnet/resnet_imagenet_main.py
...ision/image_classification/resnet/resnet_imagenet_main.py
+2
-2
official/vision/image_classification/resnet/resnet_model.py
official/vision/image_classification/resnet/resnet_model.py
+56
-124
official/vision/image_classification/resnet/resnet_runnable.py
...ial/vision/image_classification/resnet/resnet_runnable.py
+2
-2
No files found.
official/vision/image_classification/resnet/resnet_config.py
0 → 100644
View file @
1f3247f4
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Configuration definitions for ResNet losses, learning rates, and optimizers."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
typing
import
Any
,
Mapping
import
dataclasses
from
official.vision.image_classification.configs
import
base_configs
_RESNET_LR_SCHEDULE
=
[
# (multiplier, epoch to start) tuples
(
1.0
,
5
),
(
0.1
,
30
),
(
0.01
,
60
),
(
0.001
,
80
)
]
_RESNET_LR_BOUNDARIES
=
list
(
p
[
1
]
for
p
in
_RESNET_LR_SCHEDULE
[
1
:])
_RESNET_LR_MULTIPLIERS
=
list
(
p
[
0
]
for
p
in
_RESNET_LR_SCHEDULE
)
_RESNET_LR_WARMUP_EPOCHS
=
_RESNET_LR_SCHEDULE
[
0
][
1
]
@dataclasses.dataclass
class ResNetModelConfig(base_configs.ModelConfig):
  """Configuration for the ResNet model.

  Defaults target ResNet-50 classification on ImageNet (1000 classes,
  1281167 training examples per epoch) with SGD-momentum and a
  piecewise-constant learning-rate schedule with warmup.
  """
  # Human-readable model name.
  name: str = 'ResNet'
  # Number of output classes (ImageNet default).
  num_classes: int = 1000
  # Keyword arguments forwarded to the model builder. default_factory is
  # used so each config instance gets its own dict rather than sharing
  # one mutable default.
  model_params: Mapping[str, Any] = dataclasses.field(
      default_factory=lambda: {
          'num_classes': 1000,
          'batch_size': None,
          'use_l2_regularizer': True,
          'rescale_inputs': False,
      })
  # Training loss; labels are integer class ids, hence the sparse variant.
  loss: base_configs.LossConfig = base_configs.LossConfig(
      name='sparse_categorical_crossentropy')
  # SGD with momentum; moving_average_decay=None disables weight EMA.
  optimizer: base_configs.OptimizerConfig = base_configs.OptimizerConfig(
      name='momentum',
      decay=0.9,
      epsilon=0.001,
      momentum=0.9,
      moving_average_decay=None)
  # Warmup epochs, boundaries, and multipliers are derived from the
  # module-level _RESNET_LR_SCHEDULE table.
  learning_rate: base_configs.LearningRateConfig = (
      base_configs.LearningRateConfig(
          name='piecewise_constant_with_warmup',
          examples_per_epoch=1281167,
          warmup_epochs=_RESNET_LR_WARMUP_EPOCHS,
          boundaries=_RESNET_LR_BOUNDARIES,
          multipliers=_RESNET_LR_MULTIPLIERS))
official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py
View file @
1f3247f4
...
@@ -119,8 +119,8 @@ def run(flags_obj):
...
@@ -119,8 +119,8 @@ def run(flags_obj):
# TODO(anj-s): Set data_format without using Keras.
# TODO(anj-s): Set data_format without using Keras.
data_format
=
flags_obj
.
data_format
data_format
=
flags_obj
.
data_format
if
data_format
is
None
:
if
data_format
is
None
:
data_format
=
(
'channels_first'
data_format
=
(
'channels_first'
if
tf
.
config
.
list_physical_devices
(
'GPU'
)
if
tf
.
test
.
is_built_with_cuda
()
else
'channels_last'
)
else
'channels_last'
)
tf
.
keras
.
backend
.
set_image_data_format
(
data_format
)
tf
.
keras
.
backend
.
set_image_data_format
(
data_format
)
strategy
=
distribution_utils
.
get_distribution_strategy
(
strategy
=
distribution_utils
.
get_distribution_strategy
(
...
...
official/vision/image_classification/resnet/resnet_imagenet_main.py
View file @
1f3247f4
...
@@ -71,8 +71,8 @@ def run(flags_obj):
...
@@ -71,8 +71,8 @@ def run(flags_obj):
data_format
=
flags_obj
.
data_format
data_format
=
flags_obj
.
data_format
if
data_format
is
None
:
if
data_format
is
None
:
data_format
=
(
'channels_first'
data_format
=
(
'channels_first'
if
tf
.
config
.
list_physical_devices
(
'GPU'
)
if
tf
.
test
.
is_built_with_cuda
()
else
'channels_last'
)
else
'channels_last'
)
tf
.
keras
.
backend
.
set_image_data_format
(
data_format
)
tf
.
keras
.
backend
.
set_image_data_format
(
data_format
)
# Configures cluster spec for distribution strategy.
# Configures cluster spec for distribution strategy.
...
...
official/vision/image_classification/resnet/resnet_model.py
View file @
1f3247f4
...
@@ -35,15 +35,11 @@ from tensorflow.python.keras import models
...
@@ -35,15 +35,11 @@ from tensorflow.python.keras import models
from
tensorflow.python.keras
import
regularizers
from
tensorflow.python.keras
import
regularizers
from
official.vision.image_classification.resnet
import
imagenet_preprocessing
from
official.vision.image_classification.resnet
import
imagenet_preprocessing
L2_WEIGHT_DECAY
=
1e-4
BATCH_NORM_DECAY
=
0.9
BATCH_NORM_EPSILON
=
1e-5
layers
=
tf
.
keras
.
layers
layers
=
tf
.
keras
.
layers
def
_gen_l2_regularizer
(
use_l2_regularizer
=
True
):
def
_gen_l2_regularizer
(
use_l2_regularizer
=
True
,
l2_weight_decay
=
1e-4
):
return
regularizers
.
l2
(
L
2_
WEIGHT_DECAY
)
if
use_l2_regularizer
else
None
return
regularizers
.
l2
(
l
2_
weight_decay
)
if
use_l2_regularizer
else
None
def
identity_block
(
input_tensor
,
def
identity_block
(
input_tensor
,
...
@@ -51,7 +47,9 @@ def identity_block(input_tensor,
...
@@ -51,7 +47,9 @@ def identity_block(input_tensor,
filters
,
filters
,
stage
,
stage
,
block
,
block
,
use_l2_regularizer
=
True
):
use_l2_regularizer
=
True
,
batch_norm_decay
=
0.9
,
batch_norm_epsilon
=
1e-5
):
"""The identity block is the block that has no conv layer at shortcut.
"""The identity block is the block that has no conv layer at shortcut.
Args:
Args:
...
@@ -61,6 +59,8 @@ def identity_block(input_tensor,
...
@@ -61,6 +59,8 @@ def identity_block(input_tensor,
stage: integer, current stage label, used for generating layer names
stage: integer, current stage label, used for generating layer names
block: 'a','b'..., current block label, used for generating layer names
block: 'a','b'..., current block label, used for generating layer names
use_l2_regularizer: whether to use L2 regularizer on Conv layer.
use_l2_regularizer: whether to use L2 regularizer on Conv layer.
batch_norm_decay: Momentum of batch norm layers.
batch_norm_epsilon: Epsilon of batch norm layers.
Returns:
Returns:
Output tensor for the block.
Output tensor for the block.
...
@@ -82,8 +82,8 @@ def identity_block(input_tensor,
...
@@ -82,8 +82,8 @@ def identity_block(input_tensor,
input_tensor
)
input_tensor
)
x
=
layers
.
BatchNormalization
(
x
=
layers
.
BatchNormalization
(
axis
=
bn_axis
,
axis
=
bn_axis
,
momentum
=
BATCH_NORM_DECAY
,
momentum
=
batch_norm_decay
,
epsilon
=
BATCH_NORM_EPSILON
,
epsilon
=
batch_norm_epsilon
,
name
=
bn_name_base
+
'2a'
)(
name
=
bn_name_base
+
'2a'
)(
x
)
x
)
x
=
layers
.
Activation
(
'relu'
)(
x
)
x
=
layers
.
Activation
(
'relu'
)(
x
)
...
@@ -99,8 +99,8 @@ def identity_block(input_tensor,
...
@@ -99,8 +99,8 @@ def identity_block(input_tensor,
x
)
x
)
x
=
layers
.
BatchNormalization
(
x
=
layers
.
BatchNormalization
(
axis
=
bn_axis
,
axis
=
bn_axis
,
momentum
=
BATCH_NORM_DECAY
,
momentum
=
batch_norm_decay
,
epsilon
=
BATCH_NORM_EPSILON
,
epsilon
=
batch_norm_epsilon
,
name
=
bn_name_base
+
'2b'
)(
name
=
bn_name_base
+
'2b'
)(
x
)
x
)
x
=
layers
.
Activation
(
'relu'
)(
x
)
x
=
layers
.
Activation
(
'relu'
)(
x
)
...
@@ -114,8 +114,8 @@ def identity_block(input_tensor,
...
@@ -114,8 +114,8 @@ def identity_block(input_tensor,
x
)
x
)
x
=
layers
.
BatchNormalization
(
x
=
layers
.
BatchNormalization
(
axis
=
bn_axis
,
axis
=
bn_axis
,
momentum
=
BATCH_NORM_DECAY
,
momentum
=
batch_norm_decay
,
epsilon
=
BATCH_NORM_EPSILON
,
epsilon
=
batch_norm_epsilon
,
name
=
bn_name_base
+
'2c'
)(
name
=
bn_name_base
+
'2c'
)(
x
)
x
)
...
@@ -130,7 +130,9 @@ def conv_block(input_tensor,
...
@@ -130,7 +130,9 @@ def conv_block(input_tensor,
stage
,
stage
,
block
,
block
,
strides
=
(
2
,
2
),
strides
=
(
2
,
2
),
use_l2_regularizer
=
True
):
use_l2_regularizer
=
True
,
batch_norm_decay
=
0.9
,
batch_norm_epsilon
=
1e-5
):
"""A block that has a conv layer at shortcut.
"""A block that has a conv layer at shortcut.
Note that from stage 3,
Note that from stage 3,
...
@@ -145,6 +147,8 @@ def conv_block(input_tensor,
...
@@ -145,6 +147,8 @@ def conv_block(input_tensor,
block: 'a','b'..., current block label, used for generating layer names
block: 'a','b'..., current block label, used for generating layer names
strides: Strides for the second conv layer in the block.
strides: Strides for the second conv layer in the block.
use_l2_regularizer: whether to use L2 regularizer on Conv layer.
use_l2_regularizer: whether to use L2 regularizer on Conv layer.
batch_norm_decay: Momentum of batch norm layers.
batch_norm_epsilon: Epsilon of batch norm layers.
Returns:
Returns:
Output tensor for the block.
Output tensor for the block.
...
@@ -166,8 +170,8 @@ def conv_block(input_tensor,
...
@@ -166,8 +170,8 @@ def conv_block(input_tensor,
input_tensor
)
input_tensor
)
x
=
layers
.
BatchNormalization
(
x
=
layers
.
BatchNormalization
(
axis
=
bn_axis
,
axis
=
bn_axis
,
momentum
=
BATCH_NORM_DECAY
,
momentum
=
batch_norm_decay
,
epsilon
=
BATCH_NORM_EPSILON
,
epsilon
=
batch_norm_epsilon
,
name
=
bn_name_base
+
'2a'
)(
name
=
bn_name_base
+
'2a'
)(
x
)
x
)
x
=
layers
.
Activation
(
'relu'
)(
x
)
x
=
layers
.
Activation
(
'relu'
)(
x
)
...
@@ -184,8 +188,8 @@ def conv_block(input_tensor,
...
@@ -184,8 +188,8 @@ def conv_block(input_tensor,
x
)
x
)
x
=
layers
.
BatchNormalization
(
x
=
layers
.
BatchNormalization
(
axis
=
bn_axis
,
axis
=
bn_axis
,
momentum
=
BATCH_NORM_DECAY
,
momentum
=
batch_norm_decay
,
epsilon
=
BATCH_NORM_EPSILON
,
epsilon
=
batch_norm_epsilon
,
name
=
bn_name_base
+
'2b'
)(
name
=
bn_name_base
+
'2b'
)(
x
)
x
)
x
=
layers
.
Activation
(
'relu'
)(
x
)
x
=
layers
.
Activation
(
'relu'
)(
x
)
...
@@ -199,8 +203,8 @@ def conv_block(input_tensor,
...
@@ -199,8 +203,8 @@ def conv_block(input_tensor,
x
)
x
)
x
=
layers
.
BatchNormalization
(
x
=
layers
.
BatchNormalization
(
axis
=
bn_axis
,
axis
=
bn_axis
,
momentum
=
BATCH_NORM_DECAY
,
momentum
=
batch_norm_decay
,
epsilon
=
BATCH_NORM_EPSILON
,
epsilon
=
batch_norm_epsilon
,
name
=
bn_name_base
+
'2c'
)(
name
=
bn_name_base
+
'2c'
)(
x
)
x
)
...
@@ -214,8 +218,8 @@ def conv_block(input_tensor,
...
@@ -214,8 +218,8 @@ def conv_block(input_tensor,
input_tensor
)
input_tensor
)
shortcut
=
layers
.
BatchNormalization
(
shortcut
=
layers
.
BatchNormalization
(
axis
=
bn_axis
,
axis
=
bn_axis
,
momentum
=
BATCH_NORM_DECAY
,
momentum
=
batch_norm_decay
,
epsilon
=
BATCH_NORM_EPSILON
,
epsilon
=
batch_norm_epsilon
,
name
=
bn_name_base
+
'1'
)(
name
=
bn_name_base
+
'1'
)(
shortcut
)
shortcut
)
...
@@ -227,7 +231,9 @@ def conv_block(input_tensor,
...
@@ -227,7 +231,9 @@ def conv_block(input_tensor,
def
resnet50
(
num_classes
,
def
resnet50
(
num_classes
,
batch_size
=
None
,
batch_size
=
None
,
use_l2_regularizer
=
True
,
use_l2_regularizer
=
True
,
rescale_inputs
=
False
):
rescale_inputs
=
False
,
batch_norm_decay
=
0.9
,
batch_norm_epsilon
=
1e-5
):
"""Instantiates the ResNet50 architecture.
"""Instantiates the ResNet50 architecture.
Args:
Args:
...
@@ -235,6 +241,8 @@ def resnet50(num_classes,
...
@@ -235,6 +241,8 @@ def resnet50(num_classes,
batch_size: Size of the batches for each step.
batch_size: Size of the batches for each step.
use_l2_regularizer: whether to use L2 regularizer on Conv/Dense layer.
use_l2_regularizer: whether to use L2 regularizer on Conv/Dense layer.
rescale_inputs: whether to rescale inputs from 0 to 1.
rescale_inputs: whether to rescale inputs from 0 to 1.
batch_norm_decay: Momentum of batch norm layers.
batch_norm_epsilon: Epsilon of batch norm layers.
Returns:
Returns:
A Keras model instance.
A Keras model instance.
...
@@ -260,6 +268,10 @@ def resnet50(num_classes,
...
@@ -260,6 +268,10 @@ def resnet50(num_classes,
else
:
# channels_last
else
:
# channels_last
bn_axis
=
3
bn_axis
=
3
block_config
=
dict
(
use_l2_regularizer
=
use_l2_regularizer
,
batch_norm_decay
=
batch_norm_decay
,
batch_norm_epsilon
=
batch_norm_epsilon
)
x
=
layers
.
ZeroPadding2D
(
padding
=
(
3
,
3
),
name
=
'conv1_pad'
)(
x
)
x
=
layers
.
ZeroPadding2D
(
padding
=
(
3
,
3
),
name
=
'conv1_pad'
)(
x
)
x
=
layers
.
Conv2D
(
x
=
layers
.
Conv2D
(
64
,
(
7
,
7
),
64
,
(
7
,
7
),
...
@@ -272,113 +284,33 @@ def resnet50(num_classes,
...
@@ -272,113 +284,33 @@ def resnet50(num_classes,
x
)
x
)
x
=
layers
.
BatchNormalization
(
x
=
layers
.
BatchNormalization
(
axis
=
bn_axis
,
axis
=
bn_axis
,
momentum
=
BATCH_NORM_DECAY
,
momentum
=
batch_norm_decay
,
epsilon
=
BATCH_NORM_EPSILON
,
epsilon
=
batch_norm_epsilon
,
name
=
'bn_conv1'
)(
name
=
'bn_conv1'
)(
x
)
x
)
x
=
layers
.
Activation
(
'relu'
)(
x
)
x
=
layers
.
Activation
(
'relu'
)(
x
)
x
=
layers
.
MaxPooling2D
((
3
,
3
),
strides
=
(
2
,
2
),
padding
=
'same'
)(
x
)
x
=
layers
.
MaxPooling2D
((
3
,
3
),
strides
=
(
2
,
2
),
padding
=
'same'
)(
x
)
x
=
conv_block
(
x
=
conv_block
(
x
,
x
,
3
,
[
64
,
64
,
256
],
stage
=
2
,
block
=
'a'
,
strides
=
(
1
,
1
),
**
block_config
)
3
,
[
64
,
64
,
256
],
x
=
identity_block
(
x
,
3
,
[
64
,
64
,
256
],
stage
=
2
,
block
=
'b'
,
**
block_config
)
stage
=
2
,
x
=
identity_block
(
x
,
3
,
[
64
,
64
,
256
],
stage
=
2
,
block
=
'c'
,
**
block_config
)
block
=
'a'
,
strides
=
(
1
,
1
),
x
=
conv_block
(
x
,
3
,
[
128
,
128
,
512
],
stage
=
3
,
block
=
'a'
,
**
block_config
)
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
128
,
128
,
512
],
stage
=
3
,
block
=
'b'
,
**
block_config
)
x
=
identity_block
(
x
=
identity_block
(
x
,
3
,
[
128
,
128
,
512
],
stage
=
3
,
block
=
'c'
,
**
block_config
)
x
,
x
=
identity_block
(
x
,
3
,
[
128
,
128
,
512
],
stage
=
3
,
block
=
'd'
,
**
block_config
)
3
,
[
64
,
64
,
256
],
stage
=
2
,
x
=
conv_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'a'
,
**
block_config
)
block
=
'b'
,
x
=
identity_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'b'
,
**
block_config
)
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'c'
,
**
block_config
)
x
=
identity_block
(
x
=
identity_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'd'
,
**
block_config
)
x
,
x
=
identity_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'e'
,
**
block_config
)
3
,
[
64
,
64
,
256
],
x
=
identity_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'f'
,
**
block_config
)
stage
=
2
,
block
=
'c'
,
x
=
conv_block
(
x
,
3
,
[
512
,
512
,
2048
],
stage
=
5
,
block
=
'a'
,
**
block_config
)
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
512
,
512
,
2048
],
stage
=
5
,
block
=
'b'
,
**
block_config
)
x
=
identity_block
(
x
,
3
,
[
512
,
512
,
2048
],
stage
=
5
,
block
=
'c'
,
**
block_config
)
x
=
conv_block
(
x
,
3
,
[
128
,
128
,
512
],
stage
=
3
,
block
=
'a'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
128
,
128
,
512
],
stage
=
3
,
block
=
'b'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
128
,
128
,
512
],
stage
=
3
,
block
=
'c'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
128
,
128
,
512
],
stage
=
3
,
block
=
'd'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
conv_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'a'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'b'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'c'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'd'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'e'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
256
,
256
,
1024
],
stage
=
4
,
block
=
'f'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
conv_block
(
x
,
3
,
[
512
,
512
,
2048
],
stage
=
5
,
block
=
'a'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
512
,
512
,
2048
],
stage
=
5
,
block
=
'b'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
identity_block
(
x
,
3
,
[
512
,
512
,
2048
],
stage
=
5
,
block
=
'c'
,
use_l2_regularizer
=
use_l2_regularizer
)
x
=
layers
.
GlobalAveragePooling2D
()(
x
)
x
=
layers
.
GlobalAveragePooling2D
()(
x
)
x
=
layers
.
Dense
(
x
=
layers
.
Dense
(
...
...
official/vision/image_classification/resnet/resnet_runnable.py
View file @
1f3247f4
...
@@ -158,9 +158,9 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
...
@@ -158,9 +158,9 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
loss
=
tf
.
reduce_sum
(
prediction_loss
)
*
(
1.0
/
loss
=
tf
.
reduce_sum
(
prediction_loss
)
*
(
1.0
/
self
.
flags_obj
.
batch_size
)
self
.
flags_obj
.
batch_size
)
num_replicas
=
self
.
strategy
.
num_replicas_in_sync
num_replicas
=
self
.
strategy
.
num_replicas_in_sync
l2_weight_decay
=
1e-4
if
self
.
flags_obj
.
single_l2_loss_op
:
if
self
.
flags_obj
.
single_l2_loss_op
:
l2_loss
=
resnet_model
.
L2_WEIGHT_DECAY
*
2
*
tf
.
add_n
([
l2_loss
=
l2_weight_decay
*
2
*
tf
.
add_n
([
tf
.
nn
.
l2_loss
(
v
)
tf
.
nn
.
l2_loss
(
v
)
for
v
in
self
.
model
.
trainable_variables
for
v
in
self
.
model
.
trainable_variables
if
'bn'
not
in
v
.
name
if
'bn'
not
in
v
.
name
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment