Commit 8138d8b9 authored by Frederick Liu's avatar Frederick Liu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 437879827
parent 58798b7a
......@@ -237,20 +237,21 @@ class TrainerConfig(base_config.Config):
# we will retore the model states.
recovery_max_trials: int = 0
validation_summary_subdir: str = "validation"
# Configs for differential privacy
# These configs are only effective if you use create_optimizer in
# tensorflow_models/official/core/base_task.py
differential_privacy_config: Optional[
dp_configs.DifferentialPrivacyConfig] = None
@dataclasses.dataclass
class TaskConfig(base_config.Config):
"""Config passed to task."""
init_checkpoint: str = ""
model: Optional[base_config.Config] = None
train_data: DataConfig = DataConfig()
validation_data: DataConfig = DataConfig()
name: Optional[str] = None
# Configs for differential privacy
# These configs are only effective if you use create_optimizer in
# tensorflow_models/official/core/base_task.py
differential_privacy_config: Optional[
dp_configs.DifferentialPrivacyConfig] = None
@dataclasses.dataclass
......
......@@ -214,11 +214,15 @@ def create_optimizer(task: base_task.Task,
) -> tf.keras.optimizers.Optimizer:
"""A create optimizer util to be backward compatability with new args."""
if 'dp_config' in inspect.signature(task.create_optimizer).parameters:
dp_config = None
if hasattr(params.task, 'differential_privacy_config'):
dp_config = params.task.differential_privacy_config
optimizer = task.create_optimizer(
params.trainer.optimizer_config, params.runtime,
params.trainer.differential_privacy_config)
dp_config=dp_config)
else:
if params.trainer.differential_privacy_config is not None:
if hasattr(params.task, 'differential_privacy_config'
) and params.task.differential_privacy_config is not None:
raise ValueError('Differential privacy config is specified but '
'task.create_optimizer api does not accept it.')
optimizer = task.create_optimizer(
......
......@@ -19,6 +19,7 @@ import dataclasses
from official.core import config_definitions as cfg
from official.modeling import hyperparams
from official.modeling.privacy import configs as dp_configs
@dataclasses.dataclass
......@@ -35,6 +36,8 @@ class MultiTaskConfig(hyperparams.Config):
init_checkpoint: str = ""
model: hyperparams.Config = None
task_routines: Tuple[TaskRoutine, ...] = ()
differential_privacy_config: Optional[
dp_configs.DifferentialPrivacyConfig] = None
@dataclasses.dataclass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment