Commit 1792fb76 authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 300426073
parent 7bf81db8
...@@ -24,7 +24,6 @@ import copy ...@@ -24,7 +24,6 @@ import copy
import functools import functools
from typing import Any, List, Mapping, Optional, Type from typing import Any, List, Mapping, Optional, Type
from absl import logging
import dataclasses import dataclasses
import tensorflow as tf import tensorflow as tf
import yaml import yaml
...@@ -74,8 +73,8 @@ class Config(params_dict.ParamsDict): ...@@ -74,8 +73,8 @@ class Config(params_dict.ParamsDict):
"""Returns v with dicts converted to Configs, recursively.""" """Returns v with dicts converted to Configs, recursively."""
if not issubclass(subconfig_type, params_dict.ParamsDict): if not issubclass(subconfig_type, params_dict.ParamsDict):
raise TypeError( raise TypeError(
'Subconfig_type should be subclass of ParamsDict, found %r', 'Subconfig_type should be subclass of ParamsDict, found {!r}'.format(
subconfig_type) subconfig_type))
if isinstance(v, cls.IMMUTABLE_TYPES): if isinstance(v, cls.IMMUTABLE_TYPES):
return v return v
elif isinstance(v, cls.SEQUENCE_TYPES): elif isinstance(v, cls.SEQUENCE_TYPES):
...@@ -95,7 +94,7 @@ class Config(params_dict.ParamsDict): ...@@ -95,7 +94,7 @@ class Config(params_dict.ParamsDict):
elif isinstance(v, dict): elif isinstance(v, dict):
return subconfig_type(v) return subconfig_type(v)
else: else:
raise TypeError('Unknown type: %r' % type(v)) raise TypeError('Unknown type: {!r}'.format(type(v)))
@classmethod @classmethod
def _export_config(cls, v): def _export_config(cls, v):
...@@ -162,7 +161,9 @@ class Config(params_dict.ParamsDict): ...@@ -162,7 +161,9 @@ class Config(params_dict.ParamsDict):
""" """
subconfig_type = self._get_subconfig_type(k) subconfig_type = self._get_subconfig_type(k)
if isinstance(v, dict): if isinstance(v, dict):
if k not in self.__dict__: if k not in self.__dict__ or not self.__dict__[k]:
# If the key not exist or the value is None, a new Config-family object
# sould be created for the key.
self.__dict__[k] = subconfig_type(v) self.__dict__[k] = subconfig_type(v)
else: else:
self.__dict__[k].override(v) self.__dict__[k].override(v)
...@@ -193,15 +194,16 @@ class Config(params_dict.ParamsDict): ...@@ -193,15 +194,16 @@ class Config(params_dict.ParamsDict):
'Can not be overridden.'.format(k)) 'Can not be overridden.'.format(k))
if k not in self.__dict__: if k not in self.__dict__:
if is_strict: if is_strict:
raise KeyError('The key {!r} does not exist. ' raise KeyError('The key {!r} does not exist in {!r}. '
'To extend the existing keys, use ' 'To extend the existing keys, use '
'`override` with `is_strict` = False.'.format(k)) '`override` with `is_strict` = False.'.format(
k, type(self)))
else: else:
self._set(k, v) self._set(k, v)
else: else:
if isinstance(v, dict): if isinstance(v, dict) and self.__dict__[k]:
self.__dict__[k]._override(v, is_strict) # pylint: disable=protected-access self.__dict__[k]._override(v, is_strict) # pylint: disable=protected-access
elif isinstance(v, params_dict.ParamsDict): elif isinstance(v, params_dict.ParamsDict) and self.__dict__[k]:
self.__dict__[k]._override(v.as_dict(), is_strict) # pylint: disable=protected-access self.__dict__[k]._override(v.as_dict(), is_strict) # pylint: disable=protected-access
else: else:
self._set(k, v) self._set(k, v)
...@@ -268,6 +270,8 @@ class RuntimeConfig(Config): ...@@ -268,6 +270,8 @@ class RuntimeConfig(Config):
multi-worker models with DistributionStrategy. multi-worker models with DistributionStrategy.
task_index: If multi-worker training, the task index of this worker. task_index: If multi-worker training, the task index of this worker.
all_reduce_alg: Defines the algorithm for performing all-reduce. all_reduce_alg: Defines the algorithm for performing all-reduce.
num_packs: Sets `num_packs` in the cross device ops used in
MirroredStrategy. For details, see tf.distribute.NcclAllReduce.
""" """
distribution_strategy: str = 'mirrored' distribution_strategy: str = 'mirrored'
enable_eager: bool = False enable_eager: bool = False
...@@ -281,6 +285,7 @@ class RuntimeConfig(Config): ...@@ -281,6 +285,7 @@ class RuntimeConfig(Config):
worker_hosts: Optional[str] = None worker_hosts: Optional[str] = None
task_index: int = -1 task_index: int = -1
all_reduce_alg: Optional[str] = None all_reduce_alg: Optional[str] = None
num_packs: int = 1
@dataclasses.dataclass @dataclasses.dataclass
...@@ -311,4 +316,3 @@ class CallbacksConfig(Config): ...@@ -311,4 +316,3 @@ class CallbacksConfig(Config):
""" """
enable_checkpoint_and_export: bool = True enable_checkpoint_and_export: bool = True
enable_tensorboard: bool = True enable_tensorboard: bool = True
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment