Commit ed9b2039 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 304468709
parent 34a3581c
...@@ -139,7 +139,8 @@ def get_loss_scale(params: base_configs.ExperimentConfig, ...@@ -139,7 +139,8 @@ def get_loss_scale(params: base_configs.ExperimentConfig,
return loss_scale return loss_scale
elif loss_scale is not None: elif loss_scale is not None:
return float(loss_scale) return float(loss_scale)
elif params.train_dataset.dtype == 'float32': elif (params.train_dataset.dtype == 'float32' or
params.train_dataset.dtype == 'bfloat16'):
return 1. return 1.
else: else:
assert params.train_dataset.dtype == 'float16' assert params.train_dataset.dtype == 'float16'
...@@ -241,7 +242,8 @@ def initialize(params: base_configs.ExperimentConfig, ...@@ -241,7 +242,8 @@ def initialize(params: base_configs.ExperimentConfig,
num_gpus=params.runtime.num_gpus, num_gpus=params.runtime.num_gpus,
datasets_num_private_threads=params.runtime.dataset_num_private_threads) datasets_num_private_threads=params.runtime.dataset_num_private_threads)
performance.set_mixed_precision_policy(dataset_builder.dtype) performance.set_mixed_precision_policy(dataset_builder.dtype,
get_loss_scale(params))
if tf.config.list_physical_devices('GPU'): if tf.config.list_physical_devices('GPU'):
data_format = 'channels_first' data_format = 'channels_first'
else: else:
......
...@@ -67,7 +67,7 @@ def get_params_override(params_override: Mapping[str, Any]) -> str: ...@@ -67,7 +67,7 @@ def get_params_override(params_override: Mapping[str, Any]) -> str:
return '--params_override=' + json.dumps(params_override) return '--params_override=' + json.dumps(params_override)
def basic_params_override() -> MutableMapping[str, Any]: def basic_params_override(dtype: str = 'float32') -> MutableMapping[str, Any]:
"""Returns a basic parameter configuration for testing.""" """Returns a basic parameter configuration for testing."""
return { return {
'train_dataset': { 'train_dataset': {
...@@ -75,12 +75,14 @@ def basic_params_override() -> MutableMapping[str, Any]: ...@@ -75,12 +75,14 @@ def basic_params_override() -> MutableMapping[str, Any]:
'use_per_replica_batch_size': True, 'use_per_replica_batch_size': True,
'batch_size': 1, 'batch_size': 1,
'image_size': 224, 'image_size': 224,
'dtype': dtype,
}, },
'validation_dataset': { 'validation_dataset': {
'builder': 'synthetic', 'builder': 'synthetic',
'batch_size': 1, 'batch_size': 1,
'use_per_replica_batch_size': True, 'use_per_replica_batch_size': True,
'image_size': 224, 'image_size': 224,
'dtype': dtype,
}, },
'train': { 'train': {
'steps': 1, 'steps': 1,
...@@ -181,6 +183,89 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase): ...@@ -181,6 +183,89 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
model_dir=model_dir) model_dir=model_dir)
self.assertTrue(os.path.exists(export_path)) self.assertTrue(os.path.exists(export_path))
@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.one_device_strategy_gpu,
        ],
        model=[
            'efficientnet',
            'resnet',
        ],
        mode='eager',
        dataset='imagenet',
        dtype='float16',
    ))
def test_gpu_train(self, distribution, model, dataset, dtype):
  """Runs train_and_eval and export for Keras classifier models on GPU.

  Args:
    distribution: the tf.distribute strategy under test.
    model: classifier model type name ('efficientnet' or 'resnet').
    dataset: dataset name used to select the trainer configuration.
    dtype: compute dtype injected into the dataset params ('float16' here).
  """
  # Several knobs (batch size, dtype, ...) are deliberately not exposed as
  # command-line flags on classifier_train.py, so they are injected via
  # the --params_override flag instead.
  model_dir = self.get_temp_dir()
  common_flags = [
      '--data_dir=not_used',
      '--model_type=' + model,
      '--dataset=' + dataset,
  ]
  training_flags = common_flags + [
      get_params_override(basic_params_override(dtype)),
      '--mode=train_and_eval',
  ]

  # Export uses the default (float32) params plus an export destination.
  export_path = os.path.join(model_dir, 'export')
  export_overrides = basic_params_override()
  export_overrides['export'] = {'destination': export_path}
  exporting_flags = common_flags + [
      '--mode=export_only',
      get_params_override(export_overrides),
  ]

  run = functools.partial(
      classifier_trainer.run, strategy_override=distribution)
  for flags in (training_flags, exporting_flags):
    run_end_to_end(main=run, extra_flags=flags, model_dir=model_dir)
  self.assertTrue(os.path.exists(export_path))
@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.tpu_strategy,
        ],
        model=[
            'efficientnet',
            'resnet',
        ],
        mode='eager',
        dataset='imagenet',
        dtype='bfloat16',
    ))
def test_tpu_train(self, distribution, model, dataset, dtype):
  """Runs train_and_eval for Keras classifier models on TPU.

  Args:
    distribution: the tf.distribute TPU strategy under test.
    model: classifier model type name ('efficientnet' or 'resnet').
    dataset: dataset name used to select the trainer configuration.
    dtype: compute dtype injected into the dataset params ('bfloat16' here).
  """
  # Several knobs (batch size, dtype, ...) are deliberately not exposed as
  # command-line flags on classifier_train.py, so they are injected via
  # the --params_override flag instead.
  model_dir = self.get_temp_dir()
  training_flags = [
      '--data_dir=not_used',
      '--model_type=' + model,
      '--dataset=' + dataset,
      get_params_override(basic_params_override(dtype)),
      '--mode=train_and_eval',
  ]
  run = functools.partial(
      classifier_trainer.run, strategy_override=distribution)
  run_end_to_end(
      main=run, extra_flags=training_flags, model_dir=model_dir)
@combinations.generate(distribution_strategy_combinations()) @combinations.generate(distribution_strategy_combinations())
def test_end_to_end_invalid_mode(self, distribution, model, dataset): def test_end_to_end_invalid_mode(self, distribution, model, dataset):
"""Test the Keras EfficientNet model with `strategy`.""" """Test the Keras EfficientNet model with `strategy`."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment