"examples/community/iadb.py" did not exist on "de22d4cd5d98727b2069275803933652573f0c73"
Commit 26ea4d1a authored by Allen Wang's avatar Allen Wang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 306699912
parent 1fffb174
......@@ -101,8 +101,8 @@ def get_image_size_from_model(
def _get_dataset_builders(params: base_configs.ExperimentConfig,
strategy: tf.distribute.Strategy,
one_hot: bool
) -> Tuple[Any, Any, Any]:
"""Create and return train, validation, and test dataset builders."""
) -> Tuple[Any, Any]:
"""Create and return train and validation dataset builders."""
if one_hot:
logging.warning('label_smoothing > 0, so datasets will be one hot encoded.')
else:
......
......@@ -116,7 +116,7 @@ class DatasetConfig(base_config.Config):
num_channels: Union[int, str] = 'infer'
num_examples: Union[int, str] = 'infer'
batch_size: int = 128
use_per_replica_batch_size: bool = False
use_per_replica_batch_size: bool = True
num_devices: int = 1
dtype: str = 'float32'
one_hot: bool = True
......@@ -185,14 +185,14 @@ class DatasetBuilder:
def batch_size(self) -> int:
"""The batch size, multiplied by the number of replicas (if configured)."""
if self.config.use_per_replica_batch_size:
return self.global_batch_size
return self.config.batch_size * self.config.num_devices
else:
return self.config.batch_size
@property
def global_batch_size(self):
"""The global batch size across all replicas."""
return self.config.batch_size * self.config.num_devices
return self.batch_size
@property
def num_steps(self) -> int:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment