Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
428a156b
Commit
428a156b
authored
Mar 28, 2022
by
Frederick Liu
Committed by
A. Unique TensorFlower
Mar 28, 2022
Browse files
Internal change
PiperOrigin-RevId: 437879827
parent
eaf21123
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
15 additions
and
7 deletions
+15
-7
official/core/config_definitions.py
official/core/config_definitions.py
+6
-5
official/core/train_utils.py
official/core/train_utils.py
+6
-2
official/modeling/multitask/configs.py
official/modeling/multitask/configs.py
+3
-0
No files found.
official/core/config_definitions.py
View file @
428a156b
...
@@ -237,20 +237,21 @@ class TrainerConfig(base_config.Config):
...
@@ -237,20 +237,21 @@ class TrainerConfig(base_config.Config):
# we will retore the model states.
# we will retore the model states.
recovery_max_trials
:
int
=
0
recovery_max_trials
:
int
=
0
validation_summary_subdir
:
str
=
"validation"
validation_summary_subdir
:
str
=
"validation"
# Configs for differential privacy
# These configs are only effective if you use create_optimizer in
# tensorflow_models/official/core/base_task.py
differential_privacy_config
:
Optional
[
dp_configs
.
DifferentialPrivacyConfig
]
=
None
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
TaskConfig
(
base_config
.
Config
):
class
TaskConfig
(
base_config
.
Config
):
"""Config passed to task."""
init_checkpoint
:
str
=
""
init_checkpoint
:
str
=
""
model
:
Optional
[
base_config
.
Config
]
=
None
model
:
Optional
[
base_config
.
Config
]
=
None
train_data
:
DataConfig
=
DataConfig
()
train_data
:
DataConfig
=
DataConfig
()
validation_data
:
DataConfig
=
DataConfig
()
validation_data
:
DataConfig
=
DataConfig
()
name
:
Optional
[
str
]
=
None
name
:
Optional
[
str
]
=
None
# Configs for differential privacy
# These configs are only effective if you use create_optimizer in
# tensorflow_models/official/core/base_task.py
differential_privacy_config
:
Optional
[
dp_configs
.
DifferentialPrivacyConfig
]
=
None
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
...
...
official/core/train_utils.py
View file @
428a156b
...
@@ -214,11 +214,15 @@ def create_optimizer(task: base_task.Task,
...
@@ -214,11 +214,15 @@ def create_optimizer(task: base_task.Task,
)
->
tf
.
keras
.
optimizers
.
Optimizer
:
)
->
tf
.
keras
.
optimizers
.
Optimizer
:
"""A create optimizer util to be backward compatability with new args."""
"""A create optimizer util to be backward compatability with new args."""
if
'dp_config'
in
inspect
.
signature
(
task
.
create_optimizer
).
parameters
:
if
'dp_config'
in
inspect
.
signature
(
task
.
create_optimizer
).
parameters
:
dp_config
=
None
if
hasattr
(
params
.
task
,
'differential_privacy_config'
):
dp_config
=
params
.
task
.
differential_privacy_config
optimizer
=
task
.
create_optimizer
(
optimizer
=
task
.
create_optimizer
(
params
.
trainer
.
optimizer_config
,
params
.
runtime
,
params
.
trainer
.
optimizer_config
,
params
.
runtime
,
params
.
trainer
.
differential_privacy
_config
)
dp_config
=
dp
_config
)
else
:
else
:
if
params
.
trainer
.
differential_privacy_config
is
not
None
:
if
hasattr
(
params
.
task
,
'differential_privacy_config'
)
and
params
.
task
.
differential_privacy_config
is
not
None
:
raise
ValueError
(
'Differential privacy config is specified but '
raise
ValueError
(
'Differential privacy config is specified but '
'task.create_optimizer api does not accept it.'
)
'task.create_optimizer api does not accept it.'
)
optimizer
=
task
.
create_optimizer
(
optimizer
=
task
.
create_optimizer
(
...
...
official/modeling/multitask/configs.py
View file @
428a156b
...
@@ -19,6 +19,7 @@ import dataclasses
...
@@ -19,6 +19,7 @@ import dataclasses
from
official.core
import
config_definitions
as
cfg
from
official.core
import
config_definitions
as
cfg
from
official.modeling
import
hyperparams
from
official.modeling
import
hyperparams
from
official.modeling.privacy
import
configs
as
dp_configs
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
...
@@ -35,6 +36,8 @@ class MultiTaskConfig(hyperparams.Config):
...
@@ -35,6 +36,8 @@ class MultiTaskConfig(hyperparams.Config):
init_checkpoint
:
str
=
""
init_checkpoint
:
str
=
""
model
:
hyperparams
.
Config
=
None
model
:
hyperparams
.
Config
=
None
task_routines
:
Tuple
[
TaskRoutine
,
...]
=
()
task_routines
:
Tuple
[
TaskRoutine
,
...]
=
()
differential_privacy_config
:
Optional
[
dp_configs
.
DifferentialPrivacyConfig
]
=
None
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment