Commit 1759f3e0 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 321883870
parent abd09bdb
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved. # Copyright 2020 The Orbit Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
...@@ -12,6 +13,7 @@ ...@@ -12,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Orbit package definition."""
from orbit import utils from orbit import utils
from orbit.controller import Controller from orbit.controller import Controller
......
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved. # Copyright 2020 The Orbit Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
...@@ -14,14 +15,8 @@ ...@@ -14,14 +15,8 @@
# ============================================================================== # ==============================================================================
"""A light weight utilities to train TF2 models.""" """A light weight utilities to train TF2 models."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import time import time
from typing import Callable, Optional, Text, Union from typing import Callable, Optional, Text, Union
from absl import logging from absl import logging
from orbit import runner from orbit import runner
from orbit import utils from orbit import utils
...@@ -43,7 +38,7 @@ def _validate_interval(interval: Optional[int], steps_per_loop: Optional[int], ...@@ -43,7 +38,7 @@ def _validate_interval(interval: Optional[int], steps_per_loop: Optional[int],
interval_name, interval, steps_per_loop)) interval_name, interval, steps_per_loop))
class Controller(object): class Controller:
"""Class that facilitates training and evaluation of models.""" """Class that facilitates training and evaluation of models."""
def __init__( def __init__(
...@@ -396,7 +391,7 @@ class Controller(object): ...@@ -396,7 +391,7 @@ class Controller(object):
return False return False
class StepTimer(object): class StepTimer:
"""Utility class for measuring steps/second.""" """Utility class for measuring steps/second."""
def __init__(self, step): def __init__(self, step):
......
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved. # Copyright 2020 The Orbit Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
...@@ -14,10 +15,6 @@ ...@@ -14,10 +15,6 @@
# ============================================================================== # ==============================================================================
"""Tests for orbit.controller.""" """Tests for orbit.controller."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os import os
from absl import logging from absl import logging
from absl.testing import parameterized from absl.testing import parameterized
...@@ -203,7 +200,7 @@ class TestTrainerWithSummaries(standard_runner.StandardTrainer): ...@@ -203,7 +200,7 @@ class TestTrainerWithSummaries(standard_runner.StandardTrainer):
class ControllerTest(tf.test.TestCase, parameterized.TestCase): class ControllerTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self): def setUp(self):
super(ControllerTest, self).setUp() super().setUp()
self.model_dir = self.get_temp_dir() self.model_dir = self.get_temp_dir()
def test_no_checkpoint(self): def test_no_checkpoint(self):
......
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved. # Copyright 2020 The Orbit Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
...@@ -14,19 +15,12 @@ ...@@ -14,19 +15,12 @@
# ============================================================================== # ==============================================================================
"""An abstraction that users can easily handle their custom training loops.""" """An abstraction that users can easily handle their custom training loops."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import abc import abc
from typing import Dict, Optional, Text from typing import Dict, Optional, Text
import six
import tensorflow as tf import tensorflow as tf
@six.add_metaclass(abc.ABCMeta) class AbstractTrainer(tf.Module, metaclass=abc.ABCMeta):
class AbstractTrainer(tf.Module):
"""An abstract class defining the APIs required for training.""" """An abstract class defining the APIs required for training."""
@abc.abstractmethod @abc.abstractmethod
...@@ -56,8 +50,7 @@ class AbstractTrainer(tf.Module): ...@@ -56,8 +50,7 @@ class AbstractTrainer(tf.Module):
pass pass
@six.add_metaclass(abc.ABCMeta) class AbstractEvaluator(tf.Module, metaclass=abc.ABCMeta):
class AbstractEvaluator(tf.Module):
"""An abstract class defining the APIs required for evaluation.""" """An abstract class defining the APIs required for evaluation."""
@abc.abstractmethod @abc.abstractmethod
......
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved. # Copyright 2020 The Orbit Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
...@@ -14,21 +15,14 @@ ...@@ -14,21 +15,14 @@
# ============================================================================== # ==============================================================================
"""An abstraction that users can easily handle their custom training loops.""" """An abstraction that users can easily handle their custom training loops."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import abc import abc
from typing import Any, Dict, Optional, Text from typing import Any, Dict, Optional, Text
from orbit import runner from orbit import runner
from orbit import utils from orbit import utils
import six
import tensorflow as tf import tensorflow as tf
@six.add_metaclass(abc.ABCMeta) class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
class StandardTrainer(runner.AbstractTrainer):
"""Implements the standard functionality of AbstractTrainer APIs.""" """Implements the standard functionality of AbstractTrainer APIs."""
def __init__(self, def __init__(self,
...@@ -145,8 +139,7 @@ class StandardTrainer(runner.AbstractTrainer): ...@@ -145,8 +139,7 @@ class StandardTrainer(runner.AbstractTrainer):
self._train_iter = None self._train_iter = None
@six.add_metaclass(abc.ABCMeta) class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
class StandardEvaluator(runner.AbstractEvaluator):
"""Implements the standard functionality of AbstractEvaluator APIs.""" """Implements the standard functionality of AbstractEvaluator APIs."""
def __init__(self, eval_dataset, use_tf_function=True): def __init__(self, eval_dataset, use_tf_function=True):
......
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved. # Copyright 2020 The Orbit Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
......
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved. # Copyright 2020 The Orbit Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
...@@ -14,18 +15,12 @@ ...@@ -14,18 +15,12 @@
# ============================================================================== # ==============================================================================
"""Some layered modules/functions to help users writing custom training loop.""" """Some layered modules/functions to help users writing custom training loop."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import abc import abc
import contextlib import contextlib
import functools import functools
import inspect import inspect
import numpy as np import numpy as np
import six
import tensorflow as tf import tensorflow as tf
...@@ -132,10 +127,7 @@ def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs): ...@@ -132,10 +127,7 @@ def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs):
# names, pass `ctx` as the value of `input_context` when calling # names, pass `ctx` as the value of `input_context` when calling
# `dataset_or_fn`. Otherwise `ctx` will not be used when calling # `dataset_or_fn`. Otherwise `ctx` will not be used when calling
# `dataset_or_fn`. # `dataset_or_fn`.
if six.PY3: argspec = inspect.getfullargspec(dataset_or_fn)
argspec = inspect.getfullargspec(dataset_or_fn)
else:
argspec = inspect.getargspec(dataset_or_fn) # pylint: disable=deprecated-method
args_names = argspec.args args_names = argspec.args
if "input_context" in args_names: if "input_context" in args_names:
...@@ -146,7 +138,7 @@ def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs): ...@@ -146,7 +138,7 @@ def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs):
return strategy.experimental_distribute_datasets_from_function(dataset_fn) return strategy.experimental_distribute_datasets_from_function(dataset_fn)
class SummaryManager(object): class SummaryManager:
"""A class manages writing summaries.""" """A class manages writing summaries."""
def __init__(self, summary_dir, summary_fn, global_step=None): def __init__(self, summary_dir, summary_fn, global_step=None):
...@@ -201,8 +193,7 @@ class SummaryManager(object): ...@@ -201,8 +193,7 @@ class SummaryManager(object):
self._summary_fn(name, tensor, step=self._global_step) self._summary_fn(name, tensor, step=self._global_step)
@six.add_metaclass(abc.ABCMeta) class Trigger(metaclass=abc.ABCMeta):
class Trigger(object):
"""An abstract class representing a "trigger" for some event.""" """An abstract class representing a "trigger" for some event."""
@abc.abstractmethod @abc.abstractmethod
...@@ -263,7 +254,7 @@ class IntervalTrigger(Trigger): ...@@ -263,7 +254,7 @@ class IntervalTrigger(Trigger):
self._last_trigger_value = 0 self._last_trigger_value = 0
class EpochHelper(object): class EpochHelper:
"""A Helper class to handle epochs in Customized Training Loop.""" """A Helper class to handle epochs in Customized Training Loop."""
def __init__(self, epoch_steps, global_step): def __init__(self, epoch_steps, global_step):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment