ModelZoo / ResNet50_tensorflow · Commits

Commit 1ac65814
Authored Mar 03, 2020 by Yeqing Li
Committed by A. Unique TensorFlower on Mar 03, 2020
Moves some common config definitions to modeling/hyperparams/.
PiperOrigin-RevId: 298652484
parent f16b1cf3
Showing 1 changed file with 68 additions and 0 deletions.

official/modeling/hyperparams/base_config.py  +68  -0
@@ -246,3 +246,71 @@ class Config(params_dict.ParamsDict):
    default_params = {a: p for a, p in zip(attributes, args)}
    default_params.update(kwargs)
    return cls(default_params)


@dataclasses.dataclass
class RuntimeConfig(Config):
  """High-level configurations for Runtime.

  These include parameters that are not directly related to the experiment,
  e.g. directories, accelerator type, etc.

  Attributes:
    distribution_strategy: e.g. 'mirrored', 'tpu', etc.
    enable_eager: Whether or not to enable eager mode.
    enable_xla: Whether or not to enable XLA.
    per_gpu_thread_count: thread count per GPU.
    gpu_threads_enabled: Whether or not GPU threads are enabled.
    gpu_thread_mode: Whether and how the GPU device uses its own threadpool.
    dataset_num_private_threads: Number of threads for a private threadpool
      created for all datasets computation.
    tpu: The address of the TPU to use, if any.
    num_gpus: The number of GPUs to use, if any.
    worker_hosts: comma-separated list of worker ip:port pairs for running
      multi-worker models with DistributionStrategy.
    task_index: If multi-worker training, the task index of this worker.
    all_reduce_alg: Defines the algorithm for performing all-reduce.
  """
  distribution_strategy: str = 'mirrored'
  enable_eager: bool = False
  enable_xla: bool = False
  gpu_threads_enabled: bool = False
  gpu_thread_mode: Optional[str] = None
  dataset_num_private_threads: Optional[int] = None
  per_gpu_thread_count: int = 0
  tpu: Optional[str] = None
  num_gpus: int = 0
  worker_hosts: Optional[str] = None
  task_index: int = -1
  all_reduce_alg: Optional[str] = None


@dataclasses.dataclass
class TensorboardConfig(Config):
  """Configuration for Tensorboard.

  Attributes:
    track_lr: Whether or not to track the learning rate in Tensorboard. Defaults
      to True.
    write_model_weights: Whether or not to write the model weights as
      images in Tensorboard. Defaults to False.
  """
  track_lr: bool = True
  write_model_weights: bool = False


@dataclasses.dataclass
class CallbacksConfig(Config):
  """Configuration for Callbacks.

  Attributes:
    enable_checkpoint_and_export: Whether or not to enable checkpoints as a
      Callback. Defaults to True.
    enable_tensorboard: Whether or not to enable Tensorboard as a Callback.
      Defaults to True.
  """
  enable_checkpoint_and_export: bool = True
  enable_tensorboard: bool = True
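
For orientation, here is a minimal usage sketch for the three new dataclasses. It is not part of this commit: it assumes the constructor behaviour implied by the `return cls(default_params)` context line above, i.e. that a `Config` subclass accepts a dict of field overrides (and, as a dataclass, plain keyword defaults), and that fields are readable as attributes via `params_dict.ParamsDict`.

# Hypothetical usage sketch (not from this commit); see assumptions above.
from official.modeling.hyperparams import base_config

# With no arguments, every field keeps its declared default.
runtime = base_config.RuntimeConfig()
assert runtime.distribution_strategy == 'mirrored'
assert runtime.num_gpus == 0

# Override selected fields through a dict of params, mirroring the
# cls(default_params) pattern shown in the hunk's context lines.
runtime = base_config.RuntimeConfig({'num_gpus': 2, 'enable_xla': True})

tensorboard = base_config.TensorboardConfig()
callbacks = base_config.CallbacksConfig()
assert tensorboard.track_lr and callbacks.enable_checkpoint_and_export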