Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
134c9508
"src/vscode:/vscode.git/clone" did not exist on "0cfbb51b0c36f541697e5ef83296bda874ac0671"
Commit
134c9508
authored
Jul 27, 2020
by
Simon Kornblith
Committed by
A. Unique TensorFlower
Jul 27, 2020
Browse files
Internal change
PiperOrigin-RevId: 323493207
parent
8b8524c7
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
11 additions
and
5 deletions
+11
-5
orbit/controller.py
orbit/controller.py
+8
-3
orbit/controller_test.py
orbit/controller_test.py
+2
-1
orbit/utils.py
orbit/utils.py
+1
-1
No files found.
orbit/controller.py
View file @
134c9508
...
@@ -16,8 +16,9 @@
...
@@ -16,8 +16,9 @@
"""A light weight utilities to train TF2 models."""
"""A light weight utilities to train TF2 models."""
import
time
import
time
from
typing
import
Callable
,
Optional
,
Text
,
Union
from
typing
import
Callable
,
Dict
,
Optional
,
Text
,
Union
from
absl
import
logging
from
absl
import
logging
import
numpy
as
np
from
orbit
import
runner
from
orbit
import
runner
from
orbit
import
utils
from
orbit
import
utils
...
@@ -177,7 +178,7 @@ class Controller:
...
@@ -177,7 +178,7 @@ class Controller:
if
checkpoint_at_completion
:
if
checkpoint_at_completion
:
self
.
save_checkpoint
()
self
.
save_checkpoint
()
def
evaluate
(
self
,
steps
:
int
=
None
):
def
evaluate
(
self
,
steps
:
int
=
None
)
->
Optional
[
Dict
[
Text
,
np
.
number
]]
:
"""Runs evaluation.
"""Runs evaluation.
This method calls the `evaluate` method on the Evaluator object for `steps`
This method calls the `evaluate` method on the Evaluator object for `steps`
...
@@ -186,10 +187,12 @@ class Controller:
...
@@ -186,10 +187,12 @@ class Controller:
Args:
Args:
steps: The number of steps to evaluate for.
steps: The number of steps to evaluate for.
Returns:
The evaluation results as a dictionary of numpy values.
Raises:
Raises:
ValueError: If no checkpoint found in `self.checkpoint_manager.directory`.
ValueError: If no checkpoint found in `self.checkpoint_manager.directory`.
ValueError: If `evaluator` is not provided.
ValueError: If `evaluator` is not provided.
"""
"""
if
self
.
evaluator
is
None
:
if
self
.
evaluator
is
None
:
raise
ValueError
(
"`evaluator` must be provided to call `evaluate()` "
raise
ValueError
(
"`evaluator` must be provided to call `evaluate()` "
...
@@ -217,6 +220,8 @@ class Controller:
...
@@ -217,6 +220,8 @@ class Controller:
self
.
eval_summary_manager
.
write_summaries
(
eval_outputs
)
self
.
eval_summary_manager
.
write_summaries
(
eval_outputs
)
self
.
eval_summary_manager
.
flush
()
self
.
eval_summary_manager
.
flush
()
return
eval_outputs
def
restore_checkpoint
(
self
,
checkpoint_path
:
Text
=
None
):
def
restore_checkpoint
(
self
,
checkpoint_path
:
Text
=
None
):
"""Restore or initialize the model.
"""Restore or initialize the model.
...
...
orbit/controller_test.py
View file @
134c9508
...
@@ -329,7 +329,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -329,7 +329,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
checkpoint_manager
=
checkpoint_manager
,
checkpoint_manager
=
checkpoint_manager
,
summary_dir
=
os
.
path
.
join
(
self
.
model_dir
,
"summaries/train"
),
summary_dir
=
os
.
path
.
join
(
self
.
model_dir
,
"summaries/train"
),
eval_summary_dir
=
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval"
))
eval_summary_dir
=
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval"
))
test_controller
.
evaluate
(
steps
=
2
)
eval_results
=
test_controller
.
evaluate
(
steps
=
2
)
# Only eval summaries are written
# Only eval summaries are written
self
.
assertFalse
(
self
.
assertFalse
(
...
@@ -339,6 +339,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -339,6 +339,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
self
.
assertNotEmpty
(
self
.
assertNotEmpty
(
summaries_with_matching_keyword
(
summaries_with_matching_keyword
(
"eval_loss"
,
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval"
)))
"eval_loss"
,
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval"
)))
self
.
assertIn
(
"eval_loss"
,
eval_results
)
# Tests continuous eval with timeout and timeout_fn.
# Tests continuous eval with timeout and timeout_fn.
done_file
=
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval/Done"
)
done_file
=
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval/Done"
)
...
...
orbit/utils.py
View file @
134c9508
...
@@ -378,7 +378,7 @@ def get_value(x) -> np.ndarray:
...
@@ -378,7 +378,7 @@ def get_value(x) -> np.ndarray:
x: input variable.
x: input variable.
Returns:
Returns:
A Numpy array.
A Numpy array
or number
.
"""
"""
if
not
tf
.
is_tensor
(
x
):
if
not
tf
.
is_tensor
(
x
):
return
x
return
x
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment