Commit ea646b04 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 322707701
parent 7fcd7cba
...@@ -17,11 +17,33 @@ ...@@ -17,11 +17,33 @@
import abc import abc
from typing import Any, Dict, Optional, Text from typing import Any, Dict, Optional, Text
import dataclasses
from orbit import runner from orbit import runner
from orbit import utils from orbit import utils
import tensorflow as tf import tensorflow as tf
@dataclasses.dataclass(frozen=True)
class TrainerOverrides:
"""Advanced overrides for Orbit trainers.
Attributes:
use_tf_while_loop: A boolean indicates whether to wrap the train step with
a `tf.while_loop`.
use_tf_function: A boolean indicates whether a `tf.function` will be used.
If False, training will run on pure eager mode.
use_tpu_summary_optimization: A boolean indicates whether to enable the
performance optimization for summaries in TPUs. In TPUs, writing
summaries with outside compilation inside train step is slow. If True,
it creates two `tf.function` with two XLA programs: one with summaries
and one without, and run the program with summaries (slow one) only if
necessary.
"""
use_tf_while_loop: bool = True
use_tf_function: bool = True
use_tpu_summary_optimization: bool = False
class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta): class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
"""Implements the standard functionality of AbstractTrainer APIs.""" """Implements the standard functionality of AbstractTrainer APIs."""
...@@ -139,6 +161,17 @@ class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta): ...@@ -139,6 +161,17 @@ class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
self._train_iter = None self._train_iter = None
@dataclasses.dataclass(frozen=True)
class EvaluatorOverrides:
"""Advanced overrides for Orbit evaluators.
Attributes:
use_tf_function: A boolean indicates whether a `tf.function` will be used.
If False, training will run on pure eager mode.
"""
use_tf_function: bool = False
class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta): class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
"""Implements the standard functionality of AbstractEvaluator APIs.""" """Implements the standard functionality of AbstractEvaluator APIs."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment