Commit 1759f3e0 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 321883870
parent abd09bdb
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -12,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Orbit package definition."""
from orbit import utils
from orbit.controller import Controller
......
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -14,14 +15,8 @@
# ==============================================================================
"""A light weight utilities to train TF2 models."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import time
from typing import Callable, Optional, Text, Union
from absl import logging
from orbit import runner
from orbit import utils
......@@ -43,7 +38,7 @@ def _validate_interval(interval: Optional[int], steps_per_loop: Optional[int],
interval_name, interval, steps_per_loop))
class Controller(object):
class Controller:
"""Class that facilitates training and evaluation of models."""
def __init__(
......@@ -396,7 +391,7 @@ class Controller(object):
return False
class StepTimer(object):
class StepTimer:
"""Utility class for measuring steps/second."""
def __init__(self, step):
......
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -14,10 +15,6 @@
# ==============================================================================
"""Tests for orbit.controller."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import logging
from absl.testing import parameterized
......@@ -203,7 +200,7 @@ class TestTrainerWithSummaries(standard_runner.StandardTrainer):
class ControllerTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(ControllerTest, self).setUp()
super().setUp()
self.model_dir = self.get_temp_dir()
def test_no_checkpoint(self):
......
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -14,19 +15,12 @@
# ==============================================================================
"""An abstraction that users can easily handle their custom training loops."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import abc
from typing import Dict, Optional, Text
import six
import tensorflow as tf
@six.add_metaclass(abc.ABCMeta)
class AbstractTrainer(tf.Module):
class AbstractTrainer(tf.Module, metaclass=abc.ABCMeta):
"""An abstract class defining the APIs required for training."""
@abc.abstractmethod
......@@ -56,8 +50,7 @@ class AbstractTrainer(tf.Module):
pass
@six.add_metaclass(abc.ABCMeta)
class AbstractEvaluator(tf.Module):
class AbstractEvaluator(tf.Module, metaclass=abc.ABCMeta):
"""An abstract class defining the APIs required for evaluation."""
@abc.abstractmethod
......
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -14,21 +15,14 @@
# ==============================================================================
"""An abstraction that users can easily handle their custom training loops."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import abc
from typing import Any, Dict, Optional, Text
from orbit import runner
from orbit import utils
import six
import tensorflow as tf
@six.add_metaclass(abc.ABCMeta)
class StandardTrainer(runner.AbstractTrainer):
class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
"""Implements the standard functionality of AbstractTrainer APIs."""
def __init__(self,
......@@ -145,8 +139,7 @@ class StandardTrainer(runner.AbstractTrainer):
self._train_iter = None
@six.add_metaclass(abc.ABCMeta)
class StandardEvaluator(runner.AbstractEvaluator):
class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
"""Implements the standard functionality of AbstractEvaluator APIs."""
def __init__(self, eval_dataset, use_tf_function=True):
......
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -14,18 +15,12 @@
# ==============================================================================
"""Some layered modules/functions to help users writing custom training loop."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import abc
import contextlib
import functools
import inspect
import numpy as np
import six
import tensorflow as tf
......@@ -132,10 +127,7 @@ def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs):
# names, pass `ctx` as the value of `input_context` when calling
# `dataset_or_fn`. Otherwise `ctx` will not be used when calling
# `dataset_or_fn`.
if six.PY3:
argspec = inspect.getfullargspec(dataset_or_fn)
else:
argspec = inspect.getargspec(dataset_or_fn) # pylint: disable=deprecated-method
args_names = argspec.args
if "input_context" in args_names:
......@@ -146,7 +138,7 @@ def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs):
return strategy.experimental_distribute_datasets_from_function(dataset_fn)
class SummaryManager(object):
class SummaryManager:
"""A class manages writing summaries."""
def __init__(self, summary_dir, summary_fn, global_step=None):
......@@ -201,8 +193,7 @@ class SummaryManager(object):
self._summary_fn(name, tensor, step=self._global_step)
@six.add_metaclass(abc.ABCMeta)
class Trigger(object):
class Trigger(metaclass=abc.ABCMeta):
"""An abstract class representing a "trigger" for some event."""
@abc.abstractmethod
......@@ -263,7 +254,7 @@ class IntervalTrigger(Trigger):
self._last_trigger_value = 0
class EpochHelper(object):
class EpochHelper:
"""A Helper class to handle epochs in Customized Training Loop."""
def __init__(self, epoch_steps, global_step):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment