Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
e2b671b5
Commit
e2b671b5
authored
Nov 18, 2021
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Nov 18, 2021
Browse files
Internal change
PiperOrigin-RevId: 410840853
parent
9c0d7874
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
47 additions
and
22 deletions
+47
-22
official/core/actions.py
official/core/actions.py
+13
-12
official/core/actions_test.py
official/core/actions_test.py
+34
-10
No files found.
official/core/actions.py
View file @
e2b671b5
...
@@ -28,7 +28,7 @@ from official.core import config_definitions
...
@@ -28,7 +28,7 @@ from official.core import config_definitions
from
official.modeling
import
optimization
from
official.modeling
import
optimization
class
PruningAction
s
:
class
PruningAction
:
"""Train action to updates pruning related information.
"""Train action to updates pruning related information.
This action updates pruning steps at the end of trainig loop, and log
This action updates pruning steps at the end of trainig loop, and log
...
@@ -66,7 +66,7 @@ class PruningActions:
...
@@ -66,7 +66,7 @@ class PruningActions:
"""Update pruning step and log pruning summaries.
"""Update pruning step and log pruning summaries.
Args:
Args:
output: The train output
to test
.
output: The train output.
"""
"""
self
.
update_pruning_step
.
on_epoch_end
(
batch
=
None
)
self
.
update_pruning_step
.
on_epoch_end
(
batch
=
None
)
self
.
pruning_summaries
.
on_epoch_begin
(
epoch
=
None
)
self
.
pruning_summaries
.
on_epoch_begin
(
epoch
=
None
)
...
@@ -81,8 +81,11 @@ class EMACheckpointing:
...
@@ -81,8 +81,11 @@ class EMACheckpointing:
than training.
than training.
"""
"""
def
__init__
(
self
,
export_dir
:
str
,
optimizer
:
tf
.
keras
.
optimizers
.
Optimizer
,
def
__init__
(
self
,
checkpoint
:
tf
.
train
.
Checkpoint
,
max_to_keep
:
int
=
1
):
export_dir
:
str
,
optimizer
:
tf
.
keras
.
optimizers
.
Optimizer
,
checkpoint
:
tf
.
train
.
Checkpoint
,
max_to_keep
:
int
=
1
):
"""Initializes the instance.
"""Initializes the instance.
Args:
Args:
...
@@ -99,8 +102,7 @@ class EMACheckpointing:
...
@@ -99,8 +102,7 @@ class EMACheckpointing:
'EMACheckpointing action'
)
'EMACheckpointing action'
)
export_dir
=
os
.
path
.
join
(
export_dir
,
'ema_checkpoints'
)
export_dir
=
os
.
path
.
join
(
export_dir
,
'ema_checkpoints'
)
tf
.
io
.
gfile
.
makedirs
(
tf
.
io
.
gfile
.
makedirs
(
os
.
path
.
dirname
(
export_dir
))
os
.
path
.
dirname
(
export_dir
))
self
.
_optimizer
=
optimizer
self
.
_optimizer
=
optimizer
self
.
_checkpoint
=
checkpoint
self
.
_checkpoint
=
checkpoint
self
.
_checkpoint_manager
=
tf
.
train
.
CheckpointManager
(
self
.
_checkpoint_manager
=
tf
.
train
.
CheckpointManager
(
...
@@ -113,7 +115,7 @@ class EMACheckpointing:
...
@@ -113,7 +115,7 @@ class EMACheckpointing:
"""Swaps model weights, and saves the checkpoint.
"""Swaps model weights, and saves the checkpoint.
Args:
Args:
output: The train or eval output
to test
.
output: The train or eval output.
"""
"""
self
.
_optimizer
.
swap_weights
()
self
.
_optimizer
.
swap_weights
()
self
.
_checkpoint_manager
.
save
(
checkpoint_number
=
self
.
_optimizer
.
iterations
)
self
.
_checkpoint_manager
.
save
(
checkpoint_number
=
self
.
_optimizer
.
iterations
)
...
@@ -173,10 +175,9 @@ class RecoveryCondition:
...
@@ -173,10 +175,9 @@ class RecoveryCondition:
@
gin
.
configurable
@
gin
.
configurable
def
get_eval_actions
(
def
get_eval_actions
(
params
:
config_definitions
.
ExperimentConfig
,
params
:
config_definitions
.
ExperimentConfig
,
trainer
:
base_trainer
.
Trainer
,
trainer
:
base_trainer
.
Trainer
,
model_dir
:
str
)
->
List
[
orbit
.
Action
]:
model_dir
:
str
)
->
List
[
orbit
.
Action
]:
"""Gets eval actions for TFM trainer."""
"""Gets eval actions for TFM trainer."""
eval_actions
=
[]
eval_actions
=
[]
# Adds ema checkpointing action to save the average weights under
# Adds ema checkpointing action to save the average weights under
...
@@ -202,7 +203,7 @@ def get_train_actions(
...
@@ -202,7 +203,7 @@ def get_train_actions(
# Adds pruning callback actions.
# Adds pruning callback actions.
if
hasattr
(
params
.
task
,
'pruning'
):
if
hasattr
(
params
.
task
,
'pruning'
):
train_actions
.
append
(
train_actions
.
append
(
PruningAction
s
(
PruningAction
(
export_dir
=
model_dir
,
export_dir
=
model_dir
,
model
=
trainer
.
model
,
model
=
trainer
.
model
,
optimizer
=
trainer
.
optimizer
))
optimizer
=
trainer
.
optimizer
))
...
...
official/core/actions_test.py
View file @
e2b671b5
...
@@ -27,14 +27,16 @@ from official.core import actions
...
@@ -27,14 +27,16 @@ from official.core import actions
from
official.modeling
import
optimization
from
official.modeling
import
optimization
class
TestModel
(
tf
.
Module
):
class
TestModel
(
tf
.
keras
.
Model
):
def
__init__
(
self
):
def
__init__
(
self
):
self
.
value
=
tf
.
Variable
(
0
)
super
().
__init__
()
self
.
value
=
tf
.
Variable
(
0.0
)
self
.
dense
=
tf
.
keras
.
layers
.
Dense
(
2
)
_
=
self
.
dense
(
tf
.
zeros
((
2
,
2
),
tf
.
float32
))
@
tf
.
function
(
input_signature
=
[])
def
call
(
self
,
x
,
training
=
None
):
def
__call__
(
self
):
return
self
.
value
+
x
return
self
.
value
class
ActionsTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
class
ActionsTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
...
@@ -43,7 +45,7 @@ class ActionsTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -43,7 +45,7 @@ class ActionsTest(tf.test.TestCase, parameterized.TestCase):
combinations
.
combine
(
combinations
.
combine
(
distribution
=
[
distribution
=
[
strategy_combinations
.
cloud_tpu_strategy
,
strategy_combinations
.
cloud_tpu_strategy
,
strategy_combinations
.
one_device_strategy
_gpu
,
strategy_combinations
.
one_device_strategy
,
],))
],))
def
test_ema_checkpointing
(
self
,
distribution
):
def
test_ema_checkpointing
(
self
,
distribution
):
with
distribution
.
scope
():
with
distribution
.
scope
():
...
@@ -62,18 +64,25 @@ class ActionsTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -62,18 +64,25 @@ class ActionsTest(tf.test.TestCase, parameterized.TestCase):
model
.
value
.
assign
(
3
)
model
.
value
.
assign
(
3
)
# Checks model.value is 3
# Checks model.value is 3
self
.
assertEqual
(
model
(),
3
)
self
.
assertEqual
(
model
(
0.
),
3
)
ema_action
=
actions
.
EMACheckpointing
(
directory
,
optimizer
,
checkpoint
)
ema_action
=
actions
.
EMACheckpointing
(
directory
,
optimizer
,
checkpoint
)
ema_action
({})
ema_action
({})
self
.
assertNotEmpty
(
self
.
assertNotEmpty
(
tf
.
io
.
gfile
.
glob
(
os
.
path
.
join
(
directory
,
'ema_checkpoints'
)))
tf
.
io
.
gfile
.
glob
(
os
.
path
.
join
(
directory
,
'ema_checkpoints'
)))
checkpoint
.
read
(
tf
.
train
.
latest_checkpoint
(
checkpoint
.
read
(
os
.
path
.
join
(
directory
,
'ema_checkpoints'
)))
tf
.
train
.
latest_checkpoint
(
os
.
path
.
join
(
directory
,
'ema_checkpoints'
)))
# Checks model.value is 0 after swapping.
# Checks model.value is 0 after swapping.
self
.
assertEqual
(
model
(),
0
)
self
.
assertEqual
(
model
(
0.
),
0
)
# Raises an error for a normal optimizer.
with
self
.
assertRaisesRegex
(
ValueError
,
'Optimizer has to be instance of.*'
):
_
=
actions
.
EMACheckpointing
(
directory
,
tf
.
keras
.
optimizers
.
SGD
(),
checkpoint
)
@
combinations
.
generate
(
@
combinations
.
generate
(
combinations
.
combine
(
combinations
.
combine
(
...
@@ -102,6 +111,21 @@ class ActionsTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -102,6 +111,21 @@ class ActionsTest(tf.test.TestCase, parameterized.TestCase):
with
self
.
assertRaises
(
RuntimeError
):
with
self
.
assertRaises
(
RuntimeError
):
recover_condition
(
outputs
)
recover_condition
(
outputs
)
@
combinations
.
generate
(
combinations
.
combine
(
distribution
=
[
strategy_combinations
.
one_device_strategy_gpu
,
strategy_combinations
.
one_device_strategy
,
],))
def
test_pruning
(
self
,
distribution
):
with
distribution
.
scope
():
directory
=
self
.
get_temp_dir
()
model
=
TestModel
()
optimizer
=
tf
.
keras
.
optimizers
.
SGD
()
pruning
=
actions
.
PruningAction
(
directory
,
model
,
optimizer
)
pruning
({})
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment