Commit 7d210ec0 authored by A. Unique TensorFlower

Changing default dtype of Object detection to float32.

PiperOrigin-RevId: 317263038
parent c8f9cf19
@@ -19,25 +19,28 @@ from __future__ import division
 # from __future__ import google_type_annotations
 from __future__ import print_function
 
-from absl import app
-from absl import flags
-from absl import logging
-
 import functools
-import os
 import pprint
 
+# pylint: disable=g-bad-import-order
 import tensorflow as tf
+
+from absl import app
+from absl import flags
+from absl import logging
+# pylint: enable=g-bad-import-order
 
 from official.modeling.hyperparams import params_dict
 from official.modeling.training import distributed_executor as executor
 from official.utils import hyperparams_flags
-from official.utils.flags import core as flags_core
-from official.utils.misc import distribution_utils
-from official.utils.misc import keras_utils
 from official.vision.detection.configs import factory as config_factory
 from official.vision.detection.dataloader import input_reader
 from official.vision.detection.dataloader import mode_keys as ModeKeys
 from official.vision.detection.executor.detection_executor import DetectionDistributedExecutor
 from official.vision.detection.modeling import factory as model_factory
+from official.utils.flags import core as flags_core
+from official.utils.misc import distribution_utils
+from official.utils.misc import keras_utils
 
 hyperparams_flags.initialize_common_flags()
 flags_core.define_log_steps()
@@ -194,6 +197,20 @@ def run(callbacks=None):
           'strategy_config': executor.strategy_flags_dict(),
       },
       is_strict=False)
+
+  # Make sure use_tpu and strategy_type are in sync.
+  params.use_tpu = (params.strategy_type == 'tpu')
+
+  if not params.use_tpu:
+    params.override({
+        'architecture': {
+            'use_bfloat16': False,
+        },
+        'norm_activation': {
+            'use_sync_bn': False,
+        },
+    }, is_strict=True)
+
   params.validate()
   params.lock()
   pp = pprint.PrettyPrinter()
......
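For context, the added block relies on the override/validate/lock pattern of ParamsDict from official.modeling.hyperparams.params_dict. The sketch below is a minimal, self-contained illustration of that pattern, not code from this commit: the config keys (use_tpu, architecture.use_bfloat16, norm_activation.use_sync_bn) mirror the diff, while the default values and the final assertion are illustrative assumptions; in the real pipeline the defaults come from the detection config factory.

# Minimal sketch of the ParamsDict pattern used in the hunk above (not from this commit).
from official.modeling.hyperparams import params_dict

# Illustrative defaults; the real values come from config_factory.config_generator().
params = params_dict.ParamsDict({
    'strategy_type': 'mirrored',
    'use_tpu': True,
    'architecture': {'use_bfloat16': True},
    'norm_activation': {'use_sync_bn': True},
})

# Keep use_tpu consistent with the selected distribution strategy.
params.use_tpu = (params.strategy_type == 'tpu')

if not params.use_tpu:
  # Off TPU, turn off bfloat16 and sync batch norm so the model trains in
  # float32 by default.
  params.override({
      'architecture': {'use_bfloat16': False},
      'norm_activation': {'use_sync_bn': False},
  }, is_strict=True)  # is_strict=True: only keys that already exist may be overridden.

params.validate()  # checks any registered restrictions (none in this sketch)
params.lock()      # freezes the params; later mutation raises an error

assert not params.architecture.use_bfloat16

The effect, assuming these defaults, is that bfloat16 compute is only ever enabled when the TPU strategy is selected; every other strategy falls back to float32, which is what the commit title refers to.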