Commit 49b58967 authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 314670076
parent 9a0986d1
......@@ -14,7 +14,7 @@
# limitations under the License.
# ==============================================================================
"""Common configuration settings."""
from typing import Optional
from typing import Optional, Union
import dataclasses
......@@ -80,8 +80,10 @@ class RuntimeConfig(base_config.Config):
all_reduce_alg: Defines the algorithm for performing all-reduce.
num_packs: Sets `num_packs` in the cross device ops used in
MirroredStrategy. For details, see tf.distribute.NcclAllReduce.
loss_scale: The type of loss scale. This is used when setting the mixed
precision policy.
mixed_precision_dtype: dtype of mixed precision policy. It can be 'float32',
'float16', or 'bfloat16'.
loss_scale: The type of loss scale, or 'float' value. This is used when
setting the mixed precision policy.
run_eagerly: Whether or not to run the experiment eagerly.
batchnorm_spatial_persistent: Whether or not to enable the spatial
persistent mode for CuDNN batch norm kernel for improved GPU performance.
......@@ -97,7 +99,8 @@ class RuntimeConfig(base_config.Config):
task_index: int = -1
all_reduce_alg: Optional[str] = None
num_packs: int = 1
loss_scale: Optional[str] = None
loss_scale: Optional[Union[str, float]] = None
mixed_precision_dtype: Optional[str] = None
run_eagerly: bool = False
batchnorm_spatial_persistent: bool = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment