Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
6d16ae2e
Commit
6d16ae2e
authored
Apr 01, 2020
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Apr 01, 2020
Browse files
Fix channel_first layout for efficientnet.
PiperOrigin-RevId: 304281524
parent
b55c9da0
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
6 additions
and
11 deletions
+6
-11
official/vision/image_classification/classifier_trainer.py
official/vision/image_classification/classifier_trainer.py
+0
-3
official/vision/image_classification/classifier_trainer_test.py
...al/vision/image_classification/classifier_trainer_test.py
+0
-1
official/vision/image_classification/dataset_factory.py
official/vision/image_classification/dataset_factory.py
+0
-3
official/vision/image_classification/efficientnet/efficientnet_model.py
...n/image_classification/efficientnet/efficientnet_model.py
+6
-4
No files found.
official/vision/image_classification/classifier_trainer.py
View file @
6d16ae2e
...
@@ -242,9 +242,6 @@ def initialize(params: base_configs.ExperimentConfig,
...
@@ -242,9 +242,6 @@ def initialize(params: base_configs.ExperimentConfig,
datasets_num_private_threads
=
params
.
runtime
.
dataset_num_private_threads
)
datasets_num_private_threads
=
params
.
runtime
.
dataset_num_private_threads
)
performance
.
set_mixed_precision_policy
(
dataset_builder
.
dtype
)
performance
.
set_mixed_precision_policy
(
dataset_builder
.
dtype
)
if
dataset_builder
.
config
.
data_format
:
data_format
=
dataset_builder
.
config
.
data_format
if
tf
.
config
.
list_physical_devices
(
'GPU'
):
if
tf
.
config
.
list_physical_devices
(
'GPU'
):
data_format
=
'channels_first'
data_format
=
'channels_first'
else
:
else
:
...
...
official/vision/image_classification/classifier_trainer_test.py
View file @
6d16ae2e
...
@@ -264,7 +264,6 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
...
@@ -264,7 +264,6 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
fake_ds_builder
=
EmptyClass
()
fake_ds_builder
=
EmptyClass
()
fake_ds_builder
.
dtype
=
dtype
fake_ds_builder
.
dtype
=
dtype
fake_ds_builder
.
config
=
EmptyClass
()
fake_ds_builder
.
config
=
EmptyClass
()
fake_ds_builder
.
config
.
data_format
=
None
classifier_trainer
.
initialize
(
config
,
fake_ds_builder
)
classifier_trainer
.
initialize
(
config
,
fake_ds_builder
)
def
test_resume_from_checkpoint
(
self
):
def
test_resume_from_checkpoint
(
self
):
...
...
official/vision/image_classification/dataset_factory.py
View file @
6d16ae2e
...
@@ -87,8 +87,6 @@ class DatasetConfig(base_config.Config):
...
@@ -87,8 +87,6 @@ class DatasetConfig(base_config.Config):
(e.g., the number of GPUs or TPU cores).
(e.g., the number of GPUs or TPU cores).
num_devices: The number of replica devices to use. This should be set by
num_devices: The number of replica devices to use. This should be set by
`strategy.num_replicas_in_sync` when using a distribution strategy.
`strategy.num_replicas_in_sync` when using a distribution strategy.
data_format: The data format of the images. Should be 'channels_last' or
'channels_first'.
dtype: The desired dtype of the dataset. This will be set during
dtype: The desired dtype of the dataset. This will be set during
preprocessing.
preprocessing.
one_hot: Whether to apply one hot encoding. Set to `True` to be able to use
one_hot: Whether to apply one hot encoding. Set to `True` to be able to use
...
@@ -120,7 +118,6 @@ class DatasetConfig(base_config.Config):
...
@@ -120,7 +118,6 @@ class DatasetConfig(base_config.Config):
batch_size
:
int
=
128
batch_size
:
int
=
128
use_per_replica_batch_size
:
bool
=
False
use_per_replica_batch_size
:
bool
=
False
num_devices
:
int
=
1
num_devices
:
int
=
1
data_format
:
str
=
'channels_last'
dtype
:
str
=
'float32'
dtype
:
str
=
'float32'
one_hot
:
bool
=
True
one_hot
:
bool
=
True
augmenter
:
AugmentConfig
=
AugmentConfig
()
augmenter
:
AugmentConfig
=
AugmentConfig
()
...
...
official/vision/image_classification/efficientnet/efficientnet_model.py
View file @
6d16ae2e
...
@@ -166,7 +166,7 @@ def conv2d_block(inputs: tf.Tensor,
...
@@ -166,7 +166,7 @@ def conv2d_block(inputs: tf.Tensor,
batch_norm
=
common_modules
.
get_batch_norm
(
config
.
batch_norm
)
batch_norm
=
common_modules
.
get_batch_norm
(
config
.
batch_norm
)
bn_momentum
=
config
.
bn_momentum
bn_momentum
=
config
.
bn_momentum
bn_epsilon
=
config
.
bn_epsilon
bn_epsilon
=
config
.
bn_epsilon
data_format
=
config
.
data_format
data_format
=
tf
.
keras
.
backend
.
image_
data_format
()
weight_decay
=
config
.
weight_decay
weight_decay
=
config
.
weight_decay
name
=
name
or
''
name
=
name
or
''
...
@@ -223,7 +223,7 @@ def mb_conv_block(inputs: tf.Tensor,
...
@@ -223,7 +223,7 @@ def mb_conv_block(inputs: tf.Tensor,
use_se
=
config
.
use_se
use_se
=
config
.
use_se
activation
=
tf_utils
.
get_activation
(
config
.
activation
)
activation
=
tf_utils
.
get_activation
(
config
.
activation
)
drop_connect_rate
=
config
.
drop_connect_rate
drop_connect_rate
=
config
.
drop_connect_rate
data_format
=
config
.
data_format
data_format
=
tf
.
keras
.
backend
.
image_
data_format
()
use_depthwise
=
block
.
conv_type
!=
'no_depthwise'
use_depthwise
=
block
.
conv_type
!=
'no_depthwise'
prefix
=
prefix
or
''
prefix
=
prefix
or
''
...
@@ -346,12 +346,14 @@ def efficientnet(image_input: tf.keras.layers.Input,
...
@@ -346,12 +346,14 @@ def efficientnet(image_input: tf.keras.layers.Input,
num_classes
=
config
.
num_classes
num_classes
=
config
.
num_classes
input_channels
=
config
.
input_channels
input_channels
=
config
.
input_channels
rescale_input
=
config
.
rescale_input
rescale_input
=
config
.
rescale_input
data_format
=
config
.
data_format
data_format
=
tf
.
keras
.
backend
.
image_
data_format
()
dtype
=
config
.
dtype
dtype
=
config
.
dtype
weight_decay
=
config
.
weight_decay
weight_decay
=
config
.
weight_decay
x
=
image_input
x
=
image_input
if
data_format
==
'channels_first'
:
# Happens on GPU/TPU if available.
x
=
tf
.
keras
.
layers
.
Permute
((
3
,
1
,
2
))(
x
)
if
rescale_input
:
if
rescale_input
:
x
=
preprocessing
.
normalize_images
(
x
,
x
=
preprocessing
.
normalize_images
(
x
,
num_channels
=
input_channels
,
num_channels
=
input_channels
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment