Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
1759f3e0
Commit
1759f3e0
authored
Jul 17, 2020
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Jul 17, 2020
Browse files
Internal change
PiperOrigin-RevId: 321883870
parent
abd09bdb
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
19 additions
and
47 deletions
+19
-47
orbit/__init__.py
orbit/__init__.py
+2
-0
orbit/controller.py
orbit/controller.py
+3
-8
orbit/controller_test.py
orbit/controller_test.py
+2
-5
orbit/runner.py
orbit/runner.py
+3
-10
orbit/standard_runner.py
orbit/standard_runner.py
+3
-10
orbit/standard_runner_test.py
orbit/standard_runner_test.py
+1
-0
orbit/utils.py
orbit/utils.py
+5
-14
No files found.
orbit/__init__.py
View file @
1759f3e0
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
...
@@ -12,6 +13,7 @@
...
@@ -12,6 +13,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
# ==============================================================================
"""Orbit package definition."""
from
orbit
import
utils
from
orbit
import
utils
from
orbit.controller
import
Controller
from
orbit.controller
import
Controller
...
...
orbit/controller.py
View file @
1759f3e0
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
...
@@ -14,14 +15,8 @@
...
@@ -14,14 +15,8 @@
# ==============================================================================
# ==============================================================================
"""A light weight utilities to train TF2 models."""
"""A light weight utilities to train TF2 models."""
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
import
time
import
time
from
typing
import
Callable
,
Optional
,
Text
,
Union
from
typing
import
Callable
,
Optional
,
Text
,
Union
from
absl
import
logging
from
absl
import
logging
from
orbit
import
runner
from
orbit
import
runner
from
orbit
import
utils
from
orbit
import
utils
...
@@ -43,7 +38,7 @@ def _validate_interval(interval: Optional[int], steps_per_loop: Optional[int],
...
@@ -43,7 +38,7 @@ def _validate_interval(interval: Optional[int], steps_per_loop: Optional[int],
interval_name
,
interval
,
steps_per_loop
))
interval_name
,
interval
,
steps_per_loop
))
class
Controller
(
object
)
:
class
Controller
:
"""Class that facilitates training and evaluation of models."""
"""Class that facilitates training and evaluation of models."""
def
__init__
(
def
__init__
(
...
@@ -396,7 +391,7 @@ class Controller(object):
...
@@ -396,7 +391,7 @@ class Controller(object):
return
False
return
False
class
StepTimer
(
object
)
:
class
StepTimer
:
"""Utility class for measuring steps/second."""
"""Utility class for measuring steps/second."""
def
__init__
(
self
,
step
):
def
__init__
(
self
,
step
):
...
...
orbit/controller_test.py
View file @
1759f3e0
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
...
@@ -14,10 +15,6 @@
...
@@ -14,10 +15,6 @@
# ==============================================================================
# ==============================================================================
"""Tests for orbit.controller."""
"""Tests for orbit.controller."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
os
from
absl
import
logging
from
absl
import
logging
from
absl.testing
import
parameterized
from
absl.testing
import
parameterized
...
@@ -203,7 +200,7 @@ class TestTrainerWithSummaries(standard_runner.StandardTrainer):
...
@@ -203,7 +200,7 @@ class TestTrainerWithSummaries(standard_runner.StandardTrainer):
class
ControllerTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
class
ControllerTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
setUp
(
self
):
def
setUp
(
self
):
super
(
ControllerTest
,
self
).
setUp
()
super
().
setUp
()
self
.
model_dir
=
self
.
get_temp_dir
()
self
.
model_dir
=
self
.
get_temp_dir
()
def
test_no_checkpoint
(
self
):
def
test_no_checkpoint
(
self
):
...
...
orbit/runner.py
View file @
1759f3e0
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
...
@@ -14,19 +15,12 @@
...
@@ -14,19 +15,12 @@
# ==============================================================================
# ==============================================================================
"""An abstraction that users can easily handle their custom training loops."""
"""An abstraction that users can easily handle their custom training loops."""
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
import
abc
import
abc
from
typing
import
Dict
,
Optional
,
Text
from
typing
import
Dict
,
Optional
,
Text
import
six
import
tensorflow
as
tf
import
tensorflow
as
tf
@
six
.
add_metaclass
(
abc
.
ABCMeta
)
class
AbstractTrainer
(
tf
.
Module
,
metaclass
=
abc
.
ABCMeta
):
class
AbstractTrainer
(
tf
.
Module
):
"""An abstract class defining the APIs required for training."""
"""An abstract class defining the APIs required for training."""
@
abc
.
abstractmethod
@
abc
.
abstractmethod
...
@@ -56,8 +50,7 @@ class AbstractTrainer(tf.Module):
...
@@ -56,8 +50,7 @@ class AbstractTrainer(tf.Module):
pass
pass
@
six
.
add_metaclass
(
abc
.
ABCMeta
)
class
AbstractEvaluator
(
tf
.
Module
,
metaclass
=
abc
.
ABCMeta
):
class
AbstractEvaluator
(
tf
.
Module
):
"""An abstract class defining the APIs required for evaluation."""
"""An abstract class defining the APIs required for evaluation."""
@
abc
.
abstractmethod
@
abc
.
abstractmethod
...
...
orbit/standard_runner.py
View file @
1759f3e0
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
...
@@ -14,21 +15,14 @@
...
@@ -14,21 +15,14 @@
# ==============================================================================
# ==============================================================================
"""An abstraction that users can easily handle their custom training loops."""
"""An abstraction that users can easily handle their custom training loops."""
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
import
abc
import
abc
from
typing
import
Any
,
Dict
,
Optional
,
Text
from
typing
import
Any
,
Dict
,
Optional
,
Text
from
orbit
import
runner
from
orbit
import
runner
from
orbit
import
utils
from
orbit
import
utils
import
six
import
tensorflow
as
tf
import
tensorflow
as
tf
@
six
.
add_metaclass
(
abc
.
ABCMeta
)
class
StandardTrainer
(
runner
.
AbstractTrainer
,
metaclass
=
abc
.
ABCMeta
):
class
StandardTrainer
(
runner
.
AbstractTrainer
):
"""Implements the standard functionality of AbstractTrainer APIs."""
"""Implements the standard functionality of AbstractTrainer APIs."""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -145,8 +139,7 @@ class StandardTrainer(runner.AbstractTrainer):
...
@@ -145,8 +139,7 @@ class StandardTrainer(runner.AbstractTrainer):
self
.
_train_iter
=
None
self
.
_train_iter
=
None
@
six
.
add_metaclass
(
abc
.
ABCMeta
)
class
StandardEvaluator
(
runner
.
AbstractEvaluator
,
metaclass
=
abc
.
ABCMeta
):
class
StandardEvaluator
(
runner
.
AbstractEvaluator
):
"""Implements the standard functionality of AbstractEvaluator APIs."""
"""Implements the standard functionality of AbstractEvaluator APIs."""
def
__init__
(
self
,
eval_dataset
,
use_tf_function
=
True
):
def
__init__
(
self
,
eval_dataset
,
use_tf_function
=
True
):
...
...
orbit/standard_runner_test.py
View file @
1759f3e0
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
...
...
orbit/utils.py
View file @
1759f3e0
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
...
@@ -14,18 +15,12 @@
...
@@ -14,18 +15,12 @@
# ==============================================================================
# ==============================================================================
"""Some layered modules/functions to help users writing custom training loop."""
"""Some layered modules/functions to help users writing custom training loop."""
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
import
abc
import
abc
import
contextlib
import
contextlib
import
functools
import
functools
import
inspect
import
inspect
import
numpy
as
np
import
numpy
as
np
import
six
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -132,10 +127,7 @@ def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs):
...
@@ -132,10 +127,7 @@ def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs):
# names, pass `ctx` as the value of `input_context` when calling
# names, pass `ctx` as the value of `input_context` when calling
# `dataset_or_fn`. Otherwise `ctx` will not be used when calling
# `dataset_or_fn`. Otherwise `ctx` will not be used when calling
# `dataset_or_fn`.
# `dataset_or_fn`.
if
six
.
PY3
:
argspec
=
inspect
.
getfullargspec
(
dataset_or_fn
)
argspec
=
inspect
.
getfullargspec
(
dataset_or_fn
)
else
:
argspec
=
inspect
.
getargspec
(
dataset_or_fn
)
# pylint: disable=deprecated-method
args_names
=
argspec
.
args
args_names
=
argspec
.
args
if
"input_context"
in
args_names
:
if
"input_context"
in
args_names
:
...
@@ -146,7 +138,7 @@ def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs):
...
@@ -146,7 +138,7 @@ def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs):
return
strategy
.
experimental_distribute_datasets_from_function
(
dataset_fn
)
return
strategy
.
experimental_distribute_datasets_from_function
(
dataset_fn
)
class
SummaryManager
(
object
)
:
class
SummaryManager
:
"""A class manages writing summaries."""
"""A class manages writing summaries."""
def
__init__
(
self
,
summary_dir
,
summary_fn
,
global_step
=
None
):
def
__init__
(
self
,
summary_dir
,
summary_fn
,
global_step
=
None
):
...
@@ -201,8 +193,7 @@ class SummaryManager(object):
...
@@ -201,8 +193,7 @@ class SummaryManager(object):
self
.
_summary_fn
(
name
,
tensor
,
step
=
self
.
_global_step
)
self
.
_summary_fn
(
name
,
tensor
,
step
=
self
.
_global_step
)
@
six
.
add_metaclass
(
abc
.
ABCMeta
)
class
Trigger
(
metaclass
=
abc
.
ABCMeta
):
class
Trigger
(
object
):
"""An abstract class representing a "trigger" for some event."""
"""An abstract class representing a "trigger" for some event."""
@
abc
.
abstractmethod
@
abc
.
abstractmethod
...
@@ -263,7 +254,7 @@ class IntervalTrigger(Trigger):
...
@@ -263,7 +254,7 @@ class IntervalTrigger(Trigger):
self
.
_last_trigger_value
=
0
self
.
_last_trigger_value
=
0
class
EpochHelper
(
object
)
:
class
EpochHelper
:
"""A Helper class to handle epochs in Customized Training Loop."""
"""A Helper class to handle epochs in Customized Training Loop."""
def
__init__
(
self
,
epoch_steps
,
global_step
):
def
__init__
(
self
,
epoch_steps
,
global_step
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment