Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
e2b671b5
Commit
e2b671b5
authored
Nov 18, 2021
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Nov 18, 2021
Browse files
Internal change
PiperOrigin-RevId: 410840853
parent
9c0d7874
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
47 additions
and
22 deletions
+47
-22
official/core/actions.py
official/core/actions.py
+13
-12
official/core/actions_test.py
official/core/actions_test.py
+34
-10
No files found.
official/core/actions.py
View file @
e2b671b5
...
...
@@ -28,7 +28,7 @@ from official.core import config_definitions
from
official.modeling
import
optimization
class
PruningAction
s
:
class
PruningAction
:
"""Train action to updates pruning related information.
This action updates pruning steps at the end of trainig loop, and log
...
...
@@ -66,7 +66,7 @@ class PruningActions:
"""Update pruning step and log pruning summaries.
Args:
output: The train output
to test
.
output: The train output.
"""
self
.
update_pruning_step
.
on_epoch_end
(
batch
=
None
)
self
.
pruning_summaries
.
on_epoch_begin
(
epoch
=
None
)
...
...
@@ -81,8 +81,11 @@ class EMACheckpointing:
than training.
"""
def
__init__
(
self
,
export_dir
:
str
,
optimizer
:
tf
.
keras
.
optimizers
.
Optimizer
,
checkpoint
:
tf
.
train
.
Checkpoint
,
max_to_keep
:
int
=
1
):
def
__init__
(
self
,
export_dir
:
str
,
optimizer
:
tf
.
keras
.
optimizers
.
Optimizer
,
checkpoint
:
tf
.
train
.
Checkpoint
,
max_to_keep
:
int
=
1
):
"""Initializes the instance.
Args:
...
...
@@ -99,8 +102,7 @@ class EMACheckpointing:
'EMACheckpointing action'
)
export_dir
=
os
.
path
.
join
(
export_dir
,
'ema_checkpoints'
)
tf
.
io
.
gfile
.
makedirs
(
os
.
path
.
dirname
(
export_dir
))
tf
.
io
.
gfile
.
makedirs
(
os
.
path
.
dirname
(
export_dir
))
self
.
_optimizer
=
optimizer
self
.
_checkpoint
=
checkpoint
self
.
_checkpoint_manager
=
tf
.
train
.
CheckpointManager
(
...
...
@@ -113,7 +115,7 @@ class EMACheckpointing:
"""Swaps model weights, and saves the checkpoint.
Args:
output: The train or eval output
to test
.
output: The train or eval output.
"""
self
.
_optimizer
.
swap_weights
()
self
.
_checkpoint_manager
.
save
(
checkpoint_number
=
self
.
_optimizer
.
iterations
)
...
...
@@ -173,8 +175,7 @@ class RecoveryCondition:
@
gin
.
configurable
def
get_eval_actions
(
params
:
config_definitions
.
ExperimentConfig
,
def
get_eval_actions
(
params
:
config_definitions
.
ExperimentConfig
,
trainer
:
base_trainer
.
Trainer
,
model_dir
:
str
)
->
List
[
orbit
.
Action
]:
"""Gets eval actions for TFM trainer."""
...
...
@@ -202,7 +203,7 @@ def get_train_actions(
# Adds pruning callback actions.
if
hasattr
(
params
.
task
,
'pruning'
):
train_actions
.
append
(
PruningAction
s
(
PruningAction
(
export_dir
=
model_dir
,
model
=
trainer
.
model
,
optimizer
=
trainer
.
optimizer
))
...
...
official/core/actions_test.py
View file @
e2b671b5
...
...
@@ -27,14 +27,16 @@ from official.core import actions
from
official.modeling
import
optimization
class
TestModel
(
tf
.
Module
):
class
TestModel
(
tf
.
keras
.
Model
):
def
__init__
(
self
):
self
.
value
=
tf
.
Variable
(
0
)
super
().
__init__
()
self
.
value
=
tf
.
Variable
(
0.0
)
self
.
dense
=
tf
.
keras
.
layers
.
Dense
(
2
)
_
=
self
.
dense
(
tf
.
zeros
((
2
,
2
),
tf
.
float32
))
@
tf
.
function
(
input_signature
=
[])
def
__call__
(
self
):
return
self
.
value
def
call
(
self
,
x
,
training
=
None
):
return
self
.
value
+
x
class
ActionsTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
...
...
@@ -43,7 +45,7 @@ class ActionsTest(tf.test.TestCase, parameterized.TestCase):
combinations
.
combine
(
distribution
=
[
strategy_combinations
.
cloud_tpu_strategy
,
strategy_combinations
.
one_device_strategy
_gpu
,
strategy_combinations
.
one_device_strategy
,
],))
def
test_ema_checkpointing
(
self
,
distribution
):
with
distribution
.
scope
():
...
...
@@ -62,18 +64,25 @@ class ActionsTest(tf.test.TestCase, parameterized.TestCase):
model
.
value
.
assign
(
3
)
# Checks model.value is 3
self
.
assertEqual
(
model
(),
3
)
self
.
assertEqual
(
model
(
0.
),
3
)
ema_action
=
actions
.
EMACheckpointing
(
directory
,
optimizer
,
checkpoint
)
ema_action
({})
self
.
assertNotEmpty
(
tf
.
io
.
gfile
.
glob
(
os
.
path
.
join
(
directory
,
'ema_checkpoints'
)))
checkpoint
.
read
(
tf
.
train
.
latest_checkpoint
(
checkpoint
.
read
(
tf
.
train
.
latest_checkpoint
(
os
.
path
.
join
(
directory
,
'ema_checkpoints'
)))
# Checks model.value is 0 after swapping.
self
.
assertEqual
(
model
(),
0
)
self
.
assertEqual
(
model
(
0.
),
0
)
# Raises an error for a normal optimizer.
with
self
.
assertRaisesRegex
(
ValueError
,
'Optimizer has to be instance of.*'
):
_
=
actions
.
EMACheckpointing
(
directory
,
tf
.
keras
.
optimizers
.
SGD
(),
checkpoint
)
@
combinations
.
generate
(
combinations
.
combine
(
...
...
@@ -102,6 +111,21 @@ class ActionsTest(tf.test.TestCase, parameterized.TestCase):
with
self
.
assertRaises
(
RuntimeError
):
recover_condition
(
outputs
)
@
combinations
.
generate
(
combinations
.
combine
(
distribution
=
[
strategy_combinations
.
one_device_strategy_gpu
,
strategy_combinations
.
one_device_strategy
,
],))
def
test_pruning
(
self
,
distribution
):
with
distribution
.
scope
():
directory
=
self
.
get_temp_dir
()
model
=
TestModel
()
optimizer
=
tf
.
keras
.
optimizers
.
SGD
()
pruning
=
actions
.
PruningAction
(
directory
,
model
,
optimizer
)
pruning
({})
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment