Commit 26ea4d1a authored by Allen Wang's avatar Allen Wang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 306699912
parent 1fffb174
...@@ -101,8 +101,8 @@ def get_image_size_from_model( ...@@ -101,8 +101,8 @@ def get_image_size_from_model(
def _get_dataset_builders(params: base_configs.ExperimentConfig, def _get_dataset_builders(params: base_configs.ExperimentConfig,
strategy: tf.distribute.Strategy, strategy: tf.distribute.Strategy,
one_hot: bool one_hot: bool
) -> Tuple[Any, Any, Any]: ) -> Tuple[Any, Any]:
"""Create and return train, validation, and test dataset builders.""" """Create and return train and validation dataset builders."""
if one_hot: if one_hot:
logging.warning('label_smoothing > 0, so datasets will be one hot encoded.') logging.warning('label_smoothing > 0, so datasets will be one hot encoded.')
else: else:
......
...@@ -116,7 +116,7 @@ class DatasetConfig(base_config.Config): ...@@ -116,7 +116,7 @@ class DatasetConfig(base_config.Config):
num_channels: Union[int, str] = 'infer' num_channels: Union[int, str] = 'infer'
num_examples: Union[int, str] = 'infer' num_examples: Union[int, str] = 'infer'
batch_size: int = 128 batch_size: int = 128
use_per_replica_batch_size: bool = False use_per_replica_batch_size: bool = True
num_devices: int = 1 num_devices: int = 1
dtype: str = 'float32' dtype: str = 'float32'
one_hot: bool = True one_hot: bool = True
...@@ -185,14 +185,14 @@ class DatasetBuilder: ...@@ -185,14 +185,14 @@ class DatasetBuilder:
def batch_size(self) -> int: def batch_size(self) -> int:
"""The batch size, multiplied by the number of replicas (if configured).""" """The batch size, multiplied by the number of replicas (if configured)."""
if self.config.use_per_replica_batch_size: if self.config.use_per_replica_batch_size:
return self.global_batch_size return self.config.batch_size * self.config.num_devices
else: else:
return self.config.batch_size return self.config.batch_size
@property @property
def global_batch_size(self): def global_batch_size(self):
"""The global batch size across all replicas.""" """The global batch size across all replicas."""
return self.config.batch_size * self.config.num_devices return self.batch_size
@property @property
def num_steps(self) -> int: def num_steps(self) -> int:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment