Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
1759f3e0
Commit
1759f3e0
authored
Jul 17, 2020
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Jul 17, 2020
Browse files
Internal change
PiperOrigin-RevId: 321883870
parent
abd09bdb
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
19 additions
and
47 deletions
+19
-47
orbit/__init__.py
orbit/__init__.py
+2
-0
orbit/controller.py
orbit/controller.py
+3
-8
orbit/controller_test.py
orbit/controller_test.py
+2
-5
orbit/runner.py
orbit/runner.py
+3
-10
orbit/standard_runner.py
orbit/standard_runner.py
+3
-10
orbit/standard_runner_test.py
orbit/standard_runner_test.py
+1
-0
orbit/utils.py
orbit/utils.py
+5
-14
No files found.
orbit/__init__.py
View file @
1759f3e0
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
...
...
@@ -12,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Orbit package definition."""
from
orbit
import
utils
from
orbit.controller
import
Controller
...
...
orbit/controller.py
View file @
1759f3e0
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
...
...
@@ -14,14 +15,8 @@
# ==============================================================================
"""A light weight utilities to train TF2 models."""
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
import
time
from
typing
import
Callable
,
Optional
,
Text
,
Union
from
absl
import
logging
from
orbit
import
runner
from
orbit
import
utils
...
...
@@ -43,7 +38,7 @@ def _validate_interval(interval: Optional[int], steps_per_loop: Optional[int],
interval_name
,
interval
,
steps_per_loop
))
class
Controller
(
object
)
:
class
Controller
:
"""Class that facilitates training and evaluation of models."""
def
__init__
(
...
...
@@ -396,7 +391,7 @@ class Controller(object):
return
False
class
StepTimer
(
object
)
:
class
StepTimer
:
"""Utility class for measuring steps/second."""
def
__init__
(
self
,
step
):
...
...
orbit/controller_test.py
View file @
1759f3e0
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
...
...
@@ -14,10 +15,6 @@
# ==============================================================================
"""Tests for orbit.controller."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
from
absl
import
logging
from
absl.testing
import
parameterized
...
...
@@ -203,7 +200,7 @@ class TestTrainerWithSummaries(standard_runner.StandardTrainer):
class
ControllerTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
setUp
(
self
):
super
(
ControllerTest
,
self
).
setUp
()
super
().
setUp
()
self
.
model_dir
=
self
.
get_temp_dir
()
def
test_no_checkpoint
(
self
):
...
...
orbit/runner.py
View file @
1759f3e0
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
...
...
@@ -14,19 +15,12 @@
# ==============================================================================
"""An abstraction that users can easily handle their custom training loops."""
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
import
abc
from
typing
import
Dict
,
Optional
,
Text
import
six
import
tensorflow
as
tf
@
six
.
add_metaclass
(
abc
.
ABCMeta
)
class
AbstractTrainer
(
tf
.
Module
):
class
AbstractTrainer
(
tf
.
Module
,
metaclass
=
abc
.
ABCMeta
):
"""An abstract class defining the APIs required for training."""
@
abc
.
abstractmethod
...
...
@@ -56,8 +50,7 @@ class AbstractTrainer(tf.Module):
pass
@
six
.
add_metaclass
(
abc
.
ABCMeta
)
class
AbstractEvaluator
(
tf
.
Module
):
class
AbstractEvaluator
(
tf
.
Module
,
metaclass
=
abc
.
ABCMeta
):
"""An abstract class defining the APIs required for evaluation."""
@
abc
.
abstractmethod
...
...
orbit/standard_runner.py
View file @
1759f3e0
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
...
...
@@ -14,21 +15,14 @@
# ==============================================================================
"""An abstraction that users can easily handle their custom training loops."""
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
import
abc
from
typing
import
Any
,
Dict
,
Optional
,
Text
from
orbit
import
runner
from
orbit
import
utils
import
six
import
tensorflow
as
tf
@
six
.
add_metaclass
(
abc
.
ABCMeta
)
class
StandardTrainer
(
runner
.
AbstractTrainer
):
class
StandardTrainer
(
runner
.
AbstractTrainer
,
metaclass
=
abc
.
ABCMeta
):
"""Implements the standard functionality of AbstractTrainer APIs."""
def
__init__
(
self
,
...
...
@@ -145,8 +139,7 @@ class StandardTrainer(runner.AbstractTrainer):
self
.
_train_iter
=
None
@
six
.
add_metaclass
(
abc
.
ABCMeta
)
class
StandardEvaluator
(
runner
.
AbstractEvaluator
):
class
StandardEvaluator
(
runner
.
AbstractEvaluator
,
metaclass
=
abc
.
ABCMeta
):
"""Implements the standard functionality of AbstractEvaluator APIs."""
def
__init__
(
self
,
eval_dataset
,
use_tf_function
=
True
):
...
...
orbit/standard_runner_test.py
View file @
1759f3e0
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
...
...
orbit/utils.py
View file @
1759f3e0
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
...
...
@@ -14,18 +15,12 @@
# ==============================================================================
"""Some layered modules/functions to help users writing custom training loop."""
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
import
abc
import
contextlib
import
functools
import
inspect
import
numpy
as
np
import
six
import
tensorflow
as
tf
...
...
@@ -132,10 +127,7 @@ def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs):
# names, pass `ctx` as the value of `input_context` when calling
# `dataset_or_fn`. Otherwise `ctx` will not be used when calling
# `dataset_or_fn`.
if
six
.
PY3
:
argspec
=
inspect
.
getfullargspec
(
dataset_or_fn
)
else
:
argspec
=
inspect
.
getargspec
(
dataset_or_fn
)
# pylint: disable=deprecated-method
args_names
=
argspec
.
args
if
"input_context"
in
args_names
:
...
...
@@ -146,7 +138,7 @@ def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs):
return
strategy
.
experimental_distribute_datasets_from_function
(
dataset_fn
)
class
SummaryManager
(
object
)
:
class
SummaryManager
:
"""A class manages writing summaries."""
def
__init__
(
self
,
summary_dir
,
summary_fn
,
global_step
=
None
):
...
...
@@ -201,8 +193,7 @@ class SummaryManager(object):
self
.
_summary_fn
(
name
,
tensor
,
step
=
self
.
_global_step
)
@
six
.
add_metaclass
(
abc
.
ABCMeta
)
class
Trigger
(
object
):
class
Trigger
(
metaclass
=
abc
.
ABCMeta
):
"""An abstract class representing a "trigger" for some event."""
@
abc
.
abstractmethod
...
...
@@ -263,7 +254,7 @@ class IntervalTrigger(Trigger):
self
.
_last_trigger_value
=
0
class
EpochHelper
(
object
)
:
class
EpochHelper
:
"""A Helper class to handle epochs in Customized Training Loop."""
def
__init__
(
self
,
epoch_steps
,
global_step
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment