Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
dc8c2fb4
Commit
dc8c2fb4
authored
Mar 24, 2022
by
Frederick Liu
Committed by
A. Unique TensorFlower
Mar 24, 2022
Browse files
Internal change
PiperOrigin-RevId: 437120060
parent
88d7510f
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
230 additions
and
10 deletions
+230
-10
official/core/base_task.py
official/core/base_task.py
+23
-2
official/core/config_definitions.py
official/core/config_definitions.py
+6
-0
official/core/train_utils.py
official/core/train_utils.py
+20
-2
official/modeling/multitask/multitask.py
official/modeling/multitask/multitask.py
+6
-2
official/modeling/multitask/train_lib.py
official/modeling/multitask/train_lib.py
+2
-4
official/modeling/privacy/__init__.py
official/modeling/privacy/__init__.py
+14
-0
official/modeling/privacy/configs.py
official/modeling/privacy/configs.py
+24
-0
official/modeling/privacy/configs_test.py
official/modeling/privacy/configs_test.py
+41
-0
official/modeling/privacy/ops.py
official/modeling/privacy/ops.py
+42
-0
official/modeling/privacy/ops_test.py
official/modeling/privacy/ops_test.py
+52
-0
No files found.
official/core/base_task.py
View file @
dc8c2fb4
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
"""Defines the base task abstraction."""
"""Defines the base task abstraction."""
import
abc
import
abc
import
functools
from
typing
import
Optional
from
typing
import
Optional
from
absl
import
logging
from
absl
import
logging
...
@@ -22,9 +23,12 @@ import tensorflow as tf
...
@@ -22,9 +23,12 @@ import tensorflow as tf
from
official.core
import
config_definitions
from
official.core
import
config_definitions
from
official.modeling
import
optimization
from
official.modeling
import
optimization
from
official.modeling
import
performance
from
official.modeling
import
performance
from
official.modeling.privacy
import
configs
from
official.modeling.privacy
import
ops
OptimizationConfig
=
optimization
.
OptimizationConfig
OptimizationConfig
=
optimization
.
OptimizationConfig
RuntimeConfig
=
config_definitions
.
RuntimeConfig
RuntimeConfig
=
config_definitions
.
RuntimeConfig
DifferentialPrivacyConfig
=
configs
.
DifferentialPrivacyConfig
class
Task
(
tf
.
Module
,
metaclass
=
abc
.
ABCMeta
):
class
Task
(
tf
.
Module
,
metaclass
=
abc
.
ABCMeta
):
...
@@ -65,18 +69,35 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
...
@@ -65,18 +69,35 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
@
classmethod
@
classmethod
def
create_optimizer
(
cls
,
optimizer_config
:
OptimizationConfig
,
def
create_optimizer
(
cls
,
optimizer_config
:
OptimizationConfig
,
runtime_config
:
Optional
[
RuntimeConfig
]
=
None
):
runtime_config
:
Optional
[
RuntimeConfig
]
=
None
,
dp_config
:
Optional
[
DifferentialPrivacyConfig
]
=
None
):
"""Creates an TF optimizer from configurations.
"""Creates an TF optimizer from configurations.
Args:
Args:
optimizer_config: the parameters of the Optimization settings.
optimizer_config: the parameters of the Optimization settings.
runtime_config: the parameters of the runtime.
runtime_config: the parameters of the runtime.
dp_config: the parameter of differential privacy.
Returns:
Returns:
A tf.optimizers.Optimizer object.
A tf.optimizers.Optimizer object.
"""
"""
gradient_transformers
=
None
if
dp_config
is
not
None
:
logging
.
info
(
"Adding differential privacy transform with config %s."
,
dp_config
.
as_dict
())
noise_stddev
=
dp_config
.
clipping_norm
*
dp_config
.
noise_multiplier
gradient_transformers
=
[
functools
.
partial
(
ops
.
clip_l2_norm
,
l2_norm_clip
=
dp_config
.
clipping_norm
),
functools
.
partial
(
ops
.
add_noise
,
noise_stddev
=
noise_stddev
)
]
opt_factory
=
optimization
.
OptimizerFactory
(
optimizer_config
)
opt_factory
=
optimization
.
OptimizerFactory
(
optimizer_config
)
optimizer
=
opt_factory
.
build_optimizer
(
opt_factory
.
build_learning_rate
())
optimizer
=
opt_factory
.
build_optimizer
(
opt_factory
.
build_learning_rate
(),
gradient_transformers
=
gradient_transformers
)
# Configuring optimizer when loss_scale is set in runtime config. This helps
# Configuring optimizer when loss_scale is set in runtime config. This helps
# avoiding overflow/underflow for float16 computations.
# avoiding overflow/underflow for float16 computations.
if
runtime_config
:
if
runtime_config
:
...
...
official/core/config_definitions.py
View file @
dc8c2fb4
...
@@ -19,6 +19,7 @@ from typing import Optional, Sequence, Union
...
@@ -19,6 +19,7 @@ from typing import Optional, Sequence, Union
from
official.modeling.hyperparams
import
base_config
from
official.modeling.hyperparams
import
base_config
from
official.modeling.optimization.configs
import
optimization_config
from
official.modeling.optimization.configs
import
optimization_config
from
official.modeling.privacy
import
configs
as
dp_configs
OptimizationConfig
=
optimization_config
.
OptimizationConfig
OptimizationConfig
=
optimization_config
.
OptimizationConfig
...
@@ -236,6 +237,11 @@ class TrainerConfig(base_config.Config):
...
@@ -236,6 +237,11 @@ class TrainerConfig(base_config.Config):
# we will retore the model states.
# we will retore the model states.
recovery_max_trials
:
int
=
0
recovery_max_trials
:
int
=
0
validation_summary_subdir
:
str
=
"validation"
validation_summary_subdir
:
str
=
"validation"
# Configs for differential privacy
# These configs are only effective if you use create_optimizer in
# tensorflow_models/official/core/base_task.py
differential_privacy_config
:
Optional
[
dp_configs
.
DifferentialPrivacyConfig
]
=
None
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
...
...
official/core/train_utils.py
View file @
dc8c2fb4
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
"""Training utils."""
"""Training utils."""
import
copy
import
copy
import
dataclasses
import
dataclasses
import
inspect
import
json
import
json
import
os
import
os
import
pprint
import
pprint
...
@@ -208,6 +209,24 @@ class BestCheckpointExporter:
...
@@ -208,6 +209,24 @@ class BestCheckpointExporter:
return
tf
.
train
.
latest_checkpoint
(
self
.
_export_dir
)
return
tf
.
train
.
latest_checkpoint
(
self
.
_export_dir
)
def
create_optimizer
(
task
:
base_task
.
Task
,
params
:
config_definitions
.
ExperimentConfig
)
->
tf
.
keras
.
optimizers
.
Optimizer
:
"""A create optimizer util to be backward compatability with new args."""
if
'dp_config'
in
inspect
.
signature
(
task
.
create_optimizer
).
parameters
:
optimizer
=
task
.
create_optimizer
(
params
.
trainer
.
optimizer_config
,
params
.
runtime
,
params
.
trainer
.
differential_privacy_config
)
else
:
if
params
.
trainer
.
differential_privacy_config
is
not
None
:
raise
ValueError
(
'Differential privacy config is specified but '
'task.create_optimizer api does not accept it.'
)
optimizer
=
task
.
create_optimizer
(
params
.
trainer
.
optimizer_config
,
params
.
runtime
)
return
optimizer
@
gin
.
configurable
@
gin
.
configurable
def
create_trainer
(
params
:
config_definitions
.
ExperimentConfig
,
def
create_trainer
(
params
:
config_definitions
.
ExperimentConfig
,
task
:
base_task
.
Task
,
task
:
base_task
.
Task
,
...
@@ -218,8 +237,7 @@ def create_trainer(params: config_definitions.ExperimentConfig,
...
@@ -218,8 +237,7 @@ def create_trainer(params: config_definitions.ExperimentConfig,
"""Create trainer."""
"""Create trainer."""
logging
.
info
(
'Running default trainer.'
)
logging
.
info
(
'Running default trainer.'
)
model
=
task
.
build_model
()
model
=
task
.
build_model
()
optimizer
=
task
.
create_optimizer
(
params
.
trainer
.
optimizer_config
,
optimizer
=
create_optimizer
(
task
,
params
)
params
.
runtime
)
return
trainer_cls
(
return
trainer_cls
(
params
,
params
,
task
,
task
,
...
...
official/modeling/multitask/multitask.py
View file @
dc8c2fb4
...
@@ -23,9 +23,11 @@ from official.core import task_factory
...
@@ -23,9 +23,11 @@ from official.core import task_factory
from
official.modeling
import
optimization
from
official.modeling
import
optimization
from
official.modeling.multitask
import
base_model
from
official.modeling.multitask
import
base_model
from
official.modeling.multitask
import
configs
from
official.modeling.multitask
import
configs
from
official.modeling.privacy
import
configs
as
dp_configs
OptimizationConfig
=
optimization
.
OptimizationConfig
OptimizationConfig
=
optimization
.
OptimizationConfig
RuntimeConfig
=
config_definitions
.
RuntimeConfig
RuntimeConfig
=
config_definitions
.
RuntimeConfig
DifferentialPrivacyConfig
=
dp_configs
.
DifferentialPrivacyConfig
class
MultiTask
(
tf
.
Module
,
metaclass
=
abc
.
ABCMeta
):
class
MultiTask
(
tf
.
Module
,
metaclass
=
abc
.
ABCMeta
):
...
@@ -93,9 +95,11 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
...
@@ -93,9 +95,11 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
@
classmethod
@
classmethod
def
create_optimizer
(
cls
,
def
create_optimizer
(
cls
,
optimizer_config
:
OptimizationConfig
,
optimizer_config
:
OptimizationConfig
,
runtime_config
:
Optional
[
RuntimeConfig
]
=
None
):
runtime_config
:
Optional
[
RuntimeConfig
]
=
None
,
dp_config
:
Optional
[
DifferentialPrivacyConfig
]
=
None
):
return
base_task
.
Task
.
create_optimizer
(
return
base_task
.
Task
.
create_optimizer
(
optimizer_config
=
optimizer_config
,
runtime_config
=
runtime_config
)
optimizer_config
=
optimizer_config
,
runtime_config
=
runtime_config
,
dp_config
=
dp_config
)
def
joint_train_step
(
self
,
task_inputs
,
def
joint_train_step
(
self
,
task_inputs
,
multi_task_model
:
base_model
.
MultiTaskBaseModel
,
multi_task_model
:
base_model
.
MultiTaskBaseModel
,
...
...
official/modeling/multitask/train_lib.py
View file @
dc8c2fb4
...
@@ -66,8 +66,7 @@ def run_experiment(
...
@@ -66,8 +66,7 @@ def run_experiment(
is_training
=
'train'
in
mode
is_training
=
'train'
in
mode
is_eval
=
'eval'
in
mode
is_eval
=
'eval'
in
mode
with
distribution_strategy
.
scope
():
with
distribution_strategy
.
scope
():
optimizer
=
task
.
create_optimizer
(
params
.
trainer
.
optimizer_config
,
optimizer
=
train_utils
.
create_optimizer
(
task
,
params
)
params
.
runtime
)
kwargs
=
dict
(
multi_task
=
task
,
multi_task_model
=
model
,
optimizer
=
optimizer
)
kwargs
=
dict
(
multi_task
=
task
,
multi_task_model
=
model
,
optimizer
=
optimizer
)
if
params
.
trainer
.
trainer_type
==
'interleaving'
:
if
params
.
trainer
.
trainer_type
==
'interleaving'
:
sampler
=
task_sampler
.
get_task_sampler
(
params
.
trainer
.
task_sampler
,
sampler
=
task_sampler
.
get_task_sampler
(
params
.
trainer
.
task_sampler
,
...
@@ -183,8 +182,7 @@ def run_experiment_with_multitask_eval(
...
@@ -183,8 +182,7 @@ def run_experiment_with_multitask_eval(
config
=
params
,
config
=
params
,
task
=
train_task
,
task
=
train_task
,
model
=
train_task
.
build_model
(),
model
=
train_task
.
build_model
(),
optimizer
=
train_task
.
create_optimizer
(
params
.
trainer
.
optimizer_config
,
optimizer
=
train_utils
.
create_optimizer
(
train_task
,
params
),
params
.
runtime
),
train
=
True
,
train
=
True
,
evaluate
=
False
)
evaluate
=
False
)
else
:
else
:
...
...
official/modeling/privacy/__init__.py
0 → 100644
View file @
dc8c2fb4
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
official/modeling/privacy/configs.py
0 → 100644
View file @
dc8c2fb4
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configs for differential privacy."""
from
official.modeling.hyperparams
import
base_config
class
DifferentialPrivacyConfig
(
base_config
.
Config
):
# Applied to the gradients
# Setting to a large number so nothing is clipped.
clipping_norm
:
float
=
100000000.0
# 10^9
noise_multiplier
:
float
=
0.0
official/modeling/privacy/configs_test.py
0 → 100644
View file @
dc8c2fb4
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for configs."""
import
tensorflow
as
tf
from
official.modeling.privacy
import
configs
class
ConfigsTest
(
tf
.
test
.
TestCase
):
def
test_clipping_norm_default
(
self
):
clipping_norm
=
configs
.
DifferentialPrivacyConfig
().
clipping_norm
self
.
assertEqual
(
100000000.0
,
clipping_norm
)
def
test_noise_multiplier_default
(
self
):
noise_multiplier
=
configs
.
DifferentialPrivacyConfig
().
noise_multiplier
self
.
assertEqual
(
0.0
,
noise_multiplier
)
def
test_config
(
self
):
dp_config
=
configs
.
DifferentialPrivacyConfig
({
'clipping_norm'
:
1.0
,
'noise_multiplier'
:
1.0
})
self
.
assertEqual
(
1.0
,
dp_config
.
clipping_norm
)
self
.
assertEqual
(
1.0
,
dp_config
.
noise_multiplier
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/modeling/privacy/ops.py
0 → 100644
View file @
dc8c2fb4
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Ops for differential privacy (gradient) transforms."""
from
typing
import
List
,
Tuple
import
tensorflow
as
tf
def
clip_l2_norm
(
grads_vars
:
List
[
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
]],
l2_norm_clip
:
float
)
->
List
[
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
]]:
"""Clip gradients by global norm."""
gradients
=
[]
variables
=
[]
for
(
g
,
v
)
in
grads_vars
:
gradients
.
append
(
g
)
variables
.
append
(
v
)
clipped_gradients
=
tf
.
clip_by_global_norm
(
gradients
,
l2_norm_clip
)[
0
]
return
list
(
zip
(
clipped_gradients
,
variables
))
def
add_noise
(
grads_vars
:
List
[
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
]],
noise_stddev
:
float
)
->
List
[
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
]]:
"""Add noise to gradients."""
ret
=
[]
for
(
g
,
v
)
in
grads_vars
:
noise
=
tf
.
random
.
normal
(
tf
.
shape
(
g
),
stddev
=
noise_stddev
)
ret
.
append
((
g
+
noise
,
v
))
return
ret
official/modeling/privacy/ops_test.py
0 → 100644
View file @
dc8c2fb4
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for ops."""
from
unittest
import
mock
import
tensorflow
as
tf
from
official.modeling.privacy
import
ops
class
OpsTest
(
tf
.
test
.
TestCase
):
def
test_clip_l2_norm
(
self
):
x
=
tf
.
constant
([
4.0
,
3.0
])
y
=
tf
.
constant
([[
12.0
]])
tensors
=
[(
x
,
x
),
(
y
,
y
)]
clipped
=
ops
.
clip_l2_norm
(
tensors
,
1.0
)
for
a
,
b
in
zip
(
clipped
,
tensors
):
self
.
assertAllClose
(
a
[
0
],
b
[
0
]
/
13.0
)
# sqrt(4^2 + 3^2 + 12 ^3) = 13
self
.
assertAllClose
(
a
[
1
],
b
[
1
])
@
mock
.
patch
.
object
(
tf
.
random
,
'normal'
,
autospec
=
True
)
def
test_add_noise
(
self
,
mock_random
):
x
=
tf
.
constant
([
0.0
,
0.0
])
y
=
tf
.
constant
([[
0.0
]])
tensors
=
[(
x
,
x
),
(
y
,
y
)]
mock_random
.
side_effect
=
[
tf
.
constant
([
1.0
,
1.0
]),
tf
.
constant
([[
1.0
]])]
added
=
ops
.
add_noise
(
tensors
,
10.0
)
for
a
,
b
in
zip
(
added
,
tensors
):
self
.
assertAllClose
(
a
[
0
],
b
[
0
]
+
1.0
)
self
.
assertAllClose
(
a
[
1
],
b
[
1
])
_
,
kwargs
=
mock_random
.
call_args
self
.
assertEqual
(
kwargs
[
'stddev'
],
10.0
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment