Commit 428a156b authored by Frederick Liu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 437879827
parent eaf21123
@@ -237,20 +237,21 @@ class TrainerConfig(base_config.Config):
   # we will restore the model states.
   recovery_max_trials: int = 0
   validation_summary_subdir: str = "validation"
-  # Configs for differential privacy
-  # These configs are only effective if you use create_optimizer in
-  # tensorflow_models/official/core/base_task.py
-  differential_privacy_config: Optional[
-      dp_configs.DifferentialPrivacyConfig] = None


 @dataclasses.dataclass
 class TaskConfig(base_config.Config):
+  """Config passed to task."""
   init_checkpoint: str = ""
   model: Optional[base_config.Config] = None
   train_data: DataConfig = DataConfig()
   validation_data: DataConfig = DataConfig()
   name: Optional[str] = None
+  # Configs for differential privacy
+  # These configs are only effective if you use create_optimizer in
+  # tensorflow_models/official/core/base_task.py
+  differential_privacy_config: Optional[
+      dp_configs.DifferentialPrivacyConfig] = None


 @dataclasses.dataclass
...
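After this move, differential privacy is configured on the task rather than on the trainer. Below is a minimal sketch of attaching it to a task config; the keyword-construction style and the DifferentialPrivacyConfig field names (clipping_norm, noise_multiplier) are assumptions, not taken from this diff.

# Minimal sketch, assuming keyword construction of these Config dataclasses
# and assuming DifferentialPrivacyConfig exposes clipping_norm and
# noise_multiplier fields (not confirmed by this diff).
from official.core import config_definitions as cfg
from official.modeling.privacy import configs as dp_configs

task_config = cfg.TaskConfig(
    differential_privacy_config=dp_configs.DifferentialPrivacyConfig(
        clipping_norm=1.0,      # assumed field: per-example gradient clipping
        noise_multiplier=1.1,   # assumed field: Gaussian noise scale
    ))

# create_optimizer now reads the config from the task, not the trainer.
print(task_config.differential_privacy_config)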
@@ -214,11 +214,15 @@ def create_optimizer(task: base_task.Task,
 ) -> tf.keras.optimizers.Optimizer:
   """A create_optimizer util to be backward compatible with new args."""
   if 'dp_config' in inspect.signature(task.create_optimizer).parameters:
+    dp_config = None
+    if hasattr(params.task, 'differential_privacy_config'):
+      dp_config = params.task.differential_privacy_config
     optimizer = task.create_optimizer(
         params.trainer.optimizer_config, params.runtime,
-        params.trainer.differential_privacy_config)
+        dp_config=dp_config)
   else:
-    if params.trainer.differential_privacy_config is not None:
+    if hasattr(params.task, 'differential_privacy_config'
+              ) and params.task.differential_privacy_config is not None:
       raise ValueError('Differential privacy config is specified but '
                        'task.create_optimizer api does not accept it.')
     optimizer = task.create_optimizer(
...
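The backward-compatibility check above is a feature-detection pattern: inspect the signature of the override and only forward dp_config when the override declares it. Below is a self-contained sketch of the same pattern; legacy_factory, new_factory, and dispatch are illustrative names, not from the repo.

import inspect

def legacy_factory(optimizer_config, runtime_config):
  # An older override written before the dp_config keyword existed.
  return ('legacy', optimizer_config, runtime_config)

def new_factory(optimizer_config, runtime_config, dp_config=None):
  # A newer override that accepts the extra keyword.
  return ('new', optimizer_config, runtime_config, dp_config)

def dispatch(factory, optimizer_config, runtime_config, dp_config=None):
  # Mirror the diff: forward dp_config only if the callable declares it,
  # and fail loudly if a DP config is set but cannot be honored.
  if 'dp_config' in inspect.signature(factory).parameters:
    return factory(optimizer_config, runtime_config, dp_config=dp_config)
  if dp_config is not None:
    raise ValueError('Differential privacy config is specified but '
                     'the factory does not accept it.')
  return factory(optimizer_config, runtime_config)

print(dispatch(legacy_factory, 'opt_cfg', 'runtime_cfg'))
print(dispatch(new_factory, 'opt_cfg', 'runtime_cfg', dp_config={'clip': 1.0}))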
@@ -19,6 +19,7 @@ import dataclasses
 from official.core import config_definitions as cfg
 from official.modeling import hyperparams
+from official.modeling.privacy import configs as dp_configs


 @dataclasses.dataclass
...
@@ -35,6 +36,8 @@ class MultiTaskConfig(hyperparams.Config):
   init_checkpoint: str = ""
   model: hyperparams.Config = None
   task_routines: Tuple[TaskRoutine, ...] = ()
+  differential_privacy_config: Optional[
+      dp_configs.DifferentialPrivacyConfig] = None


 @dataclasses.dataclass
...
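MultiTaskConfig gains the same field, which is why the create_optimizer change above reads it through hasattr: params.task may be a TaskConfig, a MultiTaskConfig, or an older config class without the field. Below is a small sketch of that defensive access; OldTaskConfig and NewTaskConfig are hypothetical stand-ins, not classes from the repo.

import dataclasses
from typing import Any, Optional

@dataclasses.dataclass
class OldTaskConfig:
  # Hypothetical config that predates the DP field.
  init_checkpoint: str = ''

@dataclasses.dataclass
class NewTaskConfig:
  # Hypothetical post-change config carrying the DP field (type simplified).
  init_checkpoint: str = ''
  differential_privacy_config: Optional[Any] = None

def resolve_dp_config(task_config):
  # Same guard as the diff: tolerate configs without the attribute.
  if hasattr(task_config, 'differential_privacy_config'):
    return task_config.differential_privacy_config
  return None

print(resolve_dp_config(OldTaskConfig()))                                # None
print(resolve_dp_config(NewTaskConfig(differential_privacy_config={})))  # {}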