Commit 1ac65814 authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Moves some common config definition to modeliing/hyperparams/.

PiperOrigin-RevId: 298652484
parent f16b1cf3
...@@ -246,3 +246,71 @@ class Config(params_dict.ParamsDict): ...@@ -246,3 +246,71 @@ class Config(params_dict.ParamsDict):
default_params = {a: p for a, p in zip(attributes, args)} default_params = {a: p for a, p in zip(attributes, args)}
default_params.update(kwargs) default_params.update(kwargs)
return cls(default_params) return cls(default_params)
@dataclasses.dataclass
class RuntimeConfig(Config):
"""High-level configurations for Runtime.
These include parameters that are not directly related to the experiment,
e.g. directories, accelerator type, etc.
Attributes:
distribution_strategy: e.g. 'mirrored', 'tpu', etc.
enable_eager: Whether or not to enable eager mode.
enable_xla: Whether or not to enable XLA.
per_gpu_thread_count: thread count per GPU.
gpu_threads_enabled: Whether or not GPU threads are enabled.
gpu_thread_mode: Whether and how the GPU device uses its own threadpool.
dataset_num_private_threads: Number of threads for a private threadpool
created for all datasets computation.
tpu: The address of the TPU to use, if any.
num_gpus: The number of GPUs to use, if any.
worker_hosts: comma-separated list of worker ip:port pairs for running
multi-worker models with DistributionStrategy.
task_index: If multi-worker training, the task index of this worker.
all_reduce_alg: Defines the algorithm for performing all-reduce.
"""
distribution_strategy: str = 'mirrored'
enable_eager: bool = False
enable_xla: bool = False
gpu_threads_enabled: bool = False
gpu_thread_mode: Optional[str] = None
dataset_num_private_threads: Optional[int] = None
per_gpu_thread_count: int = 0
tpu: Optional[str] = None
num_gpus: int = 0
worker_hosts: Optional[str] = None
task_index: int = -1
all_reduce_alg: Optional[str] = None
@dataclasses.dataclass
class TensorboardConfig(Config):
"""Configuration for Tensorboard.
Attributes:
track_lr: Whether or not to track the learning rate in Tensorboard. Defaults
to True.
write_model_weights: Whether or not to write the model weights as
images in Tensorboard. Defaults to False.
"""
track_lr: bool = True
write_model_weights: bool = False
@dataclasses.dataclass
class CallbacksConfig(Config):
"""Configuration for Callbacks.
Attributes:
enable_checkpoint_and_export: Whether or not to enable checkpoints as a
Callback. Defaults to True.
enable_tensorboard: Whether or not to enable Tensorboard as a Callback.
Defaults to True.
"""
enable_checkpoint_and_export: bool = True
enable_tensorboard: bool = True
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment