Commit 1792fb76 authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 300426073
parent 7bf81db8
......@@ -24,7 +24,6 @@ import copy
import functools
from typing import Any, List, Mapping, Optional, Type
from absl import logging
import dataclasses
import tensorflow as tf
import yaml
......@@ -74,8 +73,8 @@ class Config(params_dict.ParamsDict):
"""Returns v with dicts converted to Configs, recursively."""
if not issubclass(subconfig_type, params_dict.ParamsDict):
raise TypeError(
'Subconfig_type should be subclass of ParamsDict, found %r',
subconfig_type)
'Subconfig_type should be subclass of ParamsDict, found {!r}'.format(
subconfig_type))
if isinstance(v, cls.IMMUTABLE_TYPES):
return v
elif isinstance(v, cls.SEQUENCE_TYPES):
......@@ -95,7 +94,7 @@ class Config(params_dict.ParamsDict):
elif isinstance(v, dict):
return subconfig_type(v)
else:
raise TypeError('Unknown type: %r' % type(v))
raise TypeError('Unknown type: {!r}'.format(type(v)))
@classmethod
def _export_config(cls, v):
......@@ -162,7 +161,9 @@ class Config(params_dict.ParamsDict):
"""
subconfig_type = self._get_subconfig_type(k)
if isinstance(v, dict):
if k not in self.__dict__:
if k not in self.__dict__ or not self.__dict__[k]:
# If the key not exist or the value is None, a new Config-family object
# sould be created for the key.
self.__dict__[k] = subconfig_type(v)
else:
self.__dict__[k].override(v)
......@@ -193,15 +194,16 @@ class Config(params_dict.ParamsDict):
'Can not be overridden.'.format(k))
if k not in self.__dict__:
if is_strict:
raise KeyError('The key {!r} does not exist. '
raise KeyError('The key {!r} does not exist in {!r}. '
'To extend the existing keys, use '
'`override` with `is_strict` = False.'.format(k))
'`override` with `is_strict` = False.'.format(
k, type(self)))
else:
self._set(k, v)
else:
if isinstance(v, dict):
if isinstance(v, dict) and self.__dict__[k]:
self.__dict__[k]._override(v, is_strict) # pylint: disable=protected-access
elif isinstance(v, params_dict.ParamsDict):
elif isinstance(v, params_dict.ParamsDict) and self.__dict__[k]:
self.__dict__[k]._override(v.as_dict(), is_strict) # pylint: disable=protected-access
else:
self._set(k, v)
......@@ -268,6 +270,8 @@ class RuntimeConfig(Config):
multi-worker models with DistributionStrategy.
task_index: If multi-worker training, the task index of this worker.
all_reduce_alg: Defines the algorithm for performing all-reduce.
num_packs: Sets `num_packs` in the cross device ops used in
MirroredStrategy. For details, see tf.distribute.NcclAllReduce.
"""
distribution_strategy: str = 'mirrored'
enable_eager: bool = False
......@@ -281,6 +285,7 @@ class RuntimeConfig(Config):
worker_hosts: Optional[str] = None
task_index: int = -1
all_reduce_alg: Optional[str] = None
num_packs: int = 1
@dataclasses.dataclass
......@@ -311,4 +316,3 @@ class CallbacksConfig(Config):
"""
enable_checkpoint_and_export: bool = True
enable_tensorboard: bool = True
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment