Commit ed9b2039 authored by Hongkun Yu, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 304468709
parent 34a3581c
......@@ -139,7 +139,8 @@ def get_loss_scale(params: base_configs.ExperimentConfig,
return loss_scale
elif loss_scale is not None:
return float(loss_scale)
elif params.train_dataset.dtype == 'float32':
elif (params.train_dataset.dtype == 'float32' or
params.train_dataset.dtype == 'bfloat16'):
return 1.
else:
assert params.train_dataset.dtype == 'float16'
......@@ -241,7 +242,8 @@ def initialize(params: base_configs.ExperimentConfig,
num_gpus=params.runtime.num_gpus,
datasets_num_private_threads=params.runtime.dataset_num_private_threads)
performance.set_mixed_precision_policy(dataset_builder.dtype)
performance.set_mixed_precision_policy(dataset_builder.dtype,
get_loss_scale(params))
if tf.config.list_physical_devices('GPU'):
data_format = 'channels_first'
else:
......
......@@ -67,7 +67,7 @@ def get_params_override(params_override: Mapping[str, Any]) -> str:
return '--params_override=' + json.dumps(params_override)
def basic_params_override() -> MutableMapping[str, Any]:
def basic_params_override(dtype: str = 'float32') -> MutableMapping[str, Any]:
"""Returns a basic parameter configuration for testing."""
return {
'train_dataset': {
......@@ -75,12 +75,14 @@ def basic_params_override() -> MutableMapping[str, Any]:
'use_per_replica_batch_size': True,
'batch_size': 1,
'image_size': 224,
'dtype': dtype,
},
'validation_dataset': {
'builder': 'synthetic',
'batch_size': 1,
'use_per_replica_batch_size': True,
'image_size': 224,
'dtype': dtype,
},
'train': {
'steps': 1,
......@@ -181,6 +183,89 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
model_dir=model_dir)
self.assertTrue(os.path.exists(export_path))
@combinations.generate(
    combinations.combine(
        distribution=[strategy_combinations.one_device_strategy_gpu],
        model=['efficientnet', 'resnet'],
        mode='eager',
        dataset='imagenet',
        dtype='float16',
    ))
def test_gpu_train(self, distribution, model, dataset, dtype):
  """Runs train_and_eval followed by export_only for a classifier model.

  The train/eval stage is launched with the parameter overrides built for
  `dtype`; the export stage is launched with the overrides built from
  `basic_params_override()` called without a dtype argument, plus an export
  destination inside the temporary model directory. The test passes when
  the export destination exists on disk after both stages have run.

  Args:
    distribution: Strategy handed to `classifier_trainer.run` as
      `strategy_override`.
    model: Value appended to the `--model_type=` flag.
    dataset: Value appended to the `--dataset=` flag.
    dtype: Dataset dtype forwarded to `basic_params_override`.
  """
  workdir = self.get_temp_dir()
  export_dir = os.path.join(workdir, 'export')

  # Flags shared by both stages.
  common_flags = [
      '--data_dir=not_used',
      '--model_type=' + model,
      '--dataset=' + dataset,
  ]

  # Settings such as batch size are deliberately not exposed as individual
  # command-line flags, so they are passed via "--params_override=...".
  train_flags = [
      *common_flags,
      get_params_override(basic_params_override(dtype)),
      '--mode=train_and_eval',
  ]

  export_overrides = basic_params_override()
  export_overrides['export'] = {'destination': export_dir}
  export_flags = [
      *common_flags,
      '--mode=export_only',
      get_params_override(export_overrides),
  ]

  trainer_main = functools.partial(
      classifier_trainer.run, strategy_override=distribution)

  # Train/eval first, then export from the same model directory.
  for stage_flags in (train_flags, export_flags):
    run_end_to_end(
        main=trainer_main, extra_flags=stage_flags, model_dir=workdir)

  self.assertTrue(os.path.exists(export_dir))
@combinations.generate(
    combinations.combine(
        distribution=[strategy_combinations.tpu_strategy],
        model=['efficientnet', 'resnet'],
        mode='eager',
        dataset='imagenet',
        dtype='bfloat16',
    ))
def test_tpu_train(self, distribution, model, dataset, dtype):
  """Runs a single train_and_eval pass for a classifier model.

  Only the train/eval stage is exercised here; no export step is run and
  nothing is asserted beyond the run completing.

  Args:
    distribution: Strategy handed to `classifier_trainer.run` as
      `strategy_override`.
    model: Value appended to the `--model_type=` flag.
    dataset: Value appended to the `--dataset=` flag.
    dtype: Dataset dtype forwarded to `basic_params_override`.
  """
  workdir = self.get_temp_dir()

  # Settings such as batch size are deliberately not exposed as individual
  # command-line flags, so they are passed via "--params_override=...".
  overrides_flag = get_params_override(basic_params_override(dtype))
  train_flags = [
      '--data_dir=not_used',
      '--model_type=' + model,
      '--dataset=' + dataset,
      overrides_flag,
      '--mode=train_and_eval',
  ]

  trainer_main = functools.partial(
      classifier_trainer.run, strategy_override=distribution)
  run_end_to_end(
      main=trainer_main, extra_flags=train_flags, model_dir=workdir)
@combinations.generate(distribution_strategy_combinations())
def test_end_to_end_invalid_mode(self, distribution, model, dataset):
"""Test the Keras EfficientNet model with `strategy`."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment