Commit 6d16ae2e authored by Hongkun Yu, committed by A. Unique TensorFlower
Browse files

Fix channels_first layout for EfficientNet.

PiperOrigin-RevId: 304281524
parent b55c9da0
...@@ -242,9 +242,6 @@ def initialize(params: base_configs.ExperimentConfig, ...@@ -242,9 +242,6 @@ def initialize(params: base_configs.ExperimentConfig,
datasets_num_private_threads=params.runtime.dataset_num_private_threads) datasets_num_private_threads=params.runtime.dataset_num_private_threads)
performance.set_mixed_precision_policy(dataset_builder.dtype) performance.set_mixed_precision_policy(dataset_builder.dtype)
if dataset_builder.config.data_format:
data_format = dataset_builder.config.data_format
if tf.config.list_physical_devices('GPU'): if tf.config.list_physical_devices('GPU'):
data_format = 'channels_first' data_format = 'channels_first'
else: else:
......
...@@ -264,7 +264,6 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase): ...@@ -264,7 +264,6 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
fake_ds_builder = EmptyClass() fake_ds_builder = EmptyClass()
fake_ds_builder.dtype = dtype fake_ds_builder.dtype = dtype
fake_ds_builder.config = EmptyClass() fake_ds_builder.config = EmptyClass()
fake_ds_builder.config.data_format = None
classifier_trainer.initialize(config, fake_ds_builder) classifier_trainer.initialize(config, fake_ds_builder)
def test_resume_from_checkpoint(self): def test_resume_from_checkpoint(self):
......
...@@ -87,8 +87,6 @@ class DatasetConfig(base_config.Config): ...@@ -87,8 +87,6 @@ class DatasetConfig(base_config.Config):
(e.g., the number of GPUs or TPU cores). (e.g., the number of GPUs or TPU cores).
num_devices: The number of replica devices to use. This should be set by num_devices: The number of replica devices to use. This should be set by
`strategy.num_replicas_in_sync` when using a distribution strategy. `strategy.num_replicas_in_sync` when using a distribution strategy.
data_format: The data format of the images. Should be 'channels_last' or
'channels_first'.
dtype: The desired dtype of the dataset. This will be set during dtype: The desired dtype of the dataset. This will be set during
preprocessing. preprocessing.
one_hot: Whether to apply one hot encoding. Set to `True` to be able to use one_hot: Whether to apply one hot encoding. Set to `True` to be able to use
...@@ -120,7 +118,6 @@ class DatasetConfig(base_config.Config): ...@@ -120,7 +118,6 @@ class DatasetConfig(base_config.Config):
batch_size: int = 128 batch_size: int = 128
use_per_replica_batch_size: bool = False use_per_replica_batch_size: bool = False
num_devices: int = 1 num_devices: int = 1
data_format: str = 'channels_last'
dtype: str = 'float32' dtype: str = 'float32'
one_hot: bool = True one_hot: bool = True
augmenter: AugmentConfig = AugmentConfig() augmenter: AugmentConfig = AugmentConfig()
......
...@@ -166,7 +166,7 @@ def conv2d_block(inputs: tf.Tensor, ...@@ -166,7 +166,7 @@ def conv2d_block(inputs: tf.Tensor,
batch_norm = common_modules.get_batch_norm(config.batch_norm) batch_norm = common_modules.get_batch_norm(config.batch_norm)
bn_momentum = config.bn_momentum bn_momentum = config.bn_momentum
bn_epsilon = config.bn_epsilon bn_epsilon = config.bn_epsilon
data_format = config.data_format data_format = tf.keras.backend.image_data_format()
weight_decay = config.weight_decay weight_decay = config.weight_decay
name = name or '' name = name or ''
...@@ -223,7 +223,7 @@ def mb_conv_block(inputs: tf.Tensor, ...@@ -223,7 +223,7 @@ def mb_conv_block(inputs: tf.Tensor,
use_se = config.use_se use_se = config.use_se
activation = tf_utils.get_activation(config.activation) activation = tf_utils.get_activation(config.activation)
drop_connect_rate = config.drop_connect_rate drop_connect_rate = config.drop_connect_rate
data_format = config.data_format data_format = tf.keras.backend.image_data_format()
use_depthwise = block.conv_type != 'no_depthwise' use_depthwise = block.conv_type != 'no_depthwise'
prefix = prefix or '' prefix = prefix or ''
...@@ -346,12 +346,14 @@ def efficientnet(image_input: tf.keras.layers.Input, ...@@ -346,12 +346,14 @@ def efficientnet(image_input: tf.keras.layers.Input,
num_classes = config.num_classes num_classes = config.num_classes
input_channels = config.input_channels input_channels = config.input_channels
rescale_input = config.rescale_input rescale_input = config.rescale_input
data_format = config.data_format data_format = tf.keras.backend.image_data_format()
dtype = config.dtype dtype = config.dtype
weight_decay = config.weight_decay weight_decay = config.weight_decay
x = image_input x = image_input
if data_format == 'channels_first':
# Happens on GPU/TPU if available.
x = tf.keras.layers.Permute((3, 1, 2))(x)
if rescale_input: if rescale_input:
x = preprocessing.normalize_images(x, x = preprocessing.normalize_images(x,
num_channels=input_channels, num_channels=input_channels,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment